|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| from datasets import Dataset, load_dataset
|
| from transformers import AutoTokenizer
|
|
|
| from trl.experimental.utils import DataCollatorForChatML, truncate_dataset
|
|
|
| from ..testing_utils import TrlTestCase
|
|
|
|
|
class TestDataCollatorForChatML(TrlTestCase):
    """Tests for `DataCollatorForChatML` on the conversational `trl-internal-testing/zen` examples."""

    def setup_method(self):
        """Load the tokenizer and example conversations, and build the collator under test."""
        self.tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
        # Fall back to the EOS token for padding when the tokenizer defines no pad token.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # Special token ids, with fallbacks of 1 / 2 when the tokenizer does not define them.
        self.bos_token_id = self.tokenizer.bos_token_id if self.tokenizer.bos_token_id is not None else 1
        self.eos_token_id = self.tokenizer.eos_token_id if self.tokenizer.eos_token_id is not None else 2

        self.ignore_index = -100
        self.max_length = 1024
        self.messages_key = "messages"

        # Each row is expected to hold a list of {"role", "content"} chat messages under `self.messages_key`.
        dataset = load_dataset("trl-internal-testing/zen", "conversational_language_modeling", split="train")
        self.examples = dataset.to_list()

        self.collator = DataCollatorForChatML(
            tokenizer=self.tokenizer,
            max_length=self.max_length,
            ignore_index=self.ignore_index,
        )

    @staticmethod
    def _unpadded(values, mask):
        """Return the entries of `values` whose corresponding `mask` entry is non-zero."""
        return [value for value, keep in zip(values, mask) if keep]

    def test_data_collator_for_chatml(self):
        """Check the collator's output keys, prompt masking, and supervised labels for the first example."""
        data = self.collator(self.examples)

        for key in ("input_ids", "attention_mask", "labels", "prompts", "prompt_attention_mask"):
            assert key in data, f"Collator output is missing '{key}'"

        # `prompts` carries its own attention mask, so it can be padded to a different length than
        # `input_ids`. Strip padding from both before comparing positions across them; otherwise
        # `len(prompt)` would not line up with the label row.
        # NOTE(review): assumes `labels` is padded exactly like `input_ids`, so the same mask applies.
        attention_mask = data["attention_mask"][0].tolist()
        input_ids = self._unpadded(data["input_ids"][0].tolist(), attention_mask)
        labels = self._unpadded(data["labels"][0].tolist(), attention_mask)
        prompt_only = self._unpadded(data["prompts"][0].tolist(), data["prompt_attention_mask"][0].tolist())

        last_message = self.examples[0][self.messages_key][-1]
        assert last_message["role"] == "assistant", "Last message should be from assistant"
        last_assistant_response = last_message["content"]

        decoded_input = self.tokenizer.decode(input_ids)
        assert last_assistant_response in decoded_input, "Input should contain assistant's response"

        decoded_prompt = self.tokenizer.decode(prompt_only)
        assert last_assistant_response not in decoded_prompt, "Prompt should not contain assistant's response"

        # Every prompt position must be excluded from the loss.
        prompt_length = len(prompt_only)
        assert all(label == self.ignore_index for label in labels[:prompt_length]), (
            "Labels should be ignore_index for prompt tokens"
        )

        # The supervised labels should be exactly the response text followed by the EOS token.
        expected_response_tokens = self.tokenizer.encode(
            last_assistant_response + self.tokenizer.eos_token, add_special_tokens=False
        )
        end_of_turn_id = self.tokenizer.convert_tokens_to_ids("<|im_end|>")
        response_labels = []
        for label in labels[prompt_length:]:
            if label == self.ignore_index:
                continue
            response_labels.append(label)
            if label == end_of_turn_id:
                break  # stop at the end-of-turn token; later positions are not compared
        assert response_labels == expected_response_tokens, "Labels should match assistant response tokens"

        generation_prompt = "<|im_start|>assistant"
        assert not decoded_input.strip().endswith(generation_prompt), (
            f"Input should not end with generation prompt '{generation_prompt}'"
        )
|
|
|
|
|
class TestTruncateExamples(TrlTestCase):
    """Tests for `truncate_dataset` on map-style and iterable datasets."""

    def test_with_dataset(self):
        """Sequences are cut to their first `max_length` entries and the dataset format is preserved."""
        examples = {
            "input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]],
            "attention_mask": [[0, 1, 1], [0, 0, 1, 1], [1]],
        }
        dataset = Dataset.from_dict(examples)
        dataset = dataset.with_format("numpy", dtype="float32")
        original_format = dataset.format  # named to avoid shadowing the builtin `format`
        max_length = 2
        expected_output = {
            "input_ids": [[1, 2], [4, 5], [8]],
            "attention_mask": [[0, 1], [0, 0], [1]],
        }
        dataset = truncate_dataset(dataset, max_length)
        assert dataset.to_dict() == expected_output
        assert original_format == dataset.format

    def test_with_iterable_dataset(self):
        """Truncation also works on an `IterableDataset` and keeps its formatting."""
        examples = {
            "input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]],
            "attention_mask": [[0, 1, 1], [0, 0, 1, 1], [1]],
        }
        dataset = Dataset.from_dict(examples).to_iterable_dataset()
        dataset = dataset.with_format("numpy")
        original_formatting = dataset._formatting
        max_length = 2
        expected_output = {
            "input_ids": [[1, 2], [4, 5], [8]],
            "attention_mask": [[0, 1], [0, 0], [1]],
        }
        dataset = truncate_dataset(dataset, max_length)
        # Read every example back as a single batch of plain Python lists for comparison.
        num_examples = len(examples["input_ids"])
        assert next(iter(dataset.with_format(None).batch(batch_size=num_examples))) == expected_output
        assert original_formatting == dataset._formatting

    def test_with_extra_column(self):
        """Columns other than the token sequences are passed through unchanged."""
        examples = {
            "input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]],
            "attention_mask": [[0, 1, 1], [0, 0, 1, 1], [1]],
            "my_column": ["a", "b", "c"],
        }
        dataset = Dataset.from_dict(examples)
        max_length = 2
        expected_output = {
            "input_ids": [[1, 2], [4, 5], [8]],
            "attention_mask": [[0, 1], [0, 0], [1]],
            "my_column": ["a", "b", "c"],
        }
        dataset = truncate_dataset(dataset, max_length)
        assert dataset.to_dict() == expected_output
|
|
|