|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import unittest |
|
|
|
|
|
import torch |
|
|
from transformers import AutoTokenizer |
|
|
|
|
|
from trl import DataCollatorForCompletionOnlyLM |
|
|
|
|
|
|
|
|
class DataCollatorForCompletionOnlyLMTester(unittest.TestCase):
    """Tests for how ``DataCollatorForCompletionOnlyLM`` masks prompt tokens in ``labels``."""

    # Label value these tests expect the collator to write for positions excluded from the loss.
    IGNORE_INDEX = -100

    @classmethod
    def setUpClass(cls):
        # Every test uses the same tokenizer; load it once instead of once per test.
        cls.tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/dummy-GPT2-correct-vocab")

    def _assert_prompt_masked_and_completion_labeled(self, batch):
        """Assert the collator located the response template in the first example of *batch*.

        An all-masked label row is what the collator produces when the template is missing
        (see ``test_data_collator_handling_of_long_sequences``), so that is ruled out first.
        Then the leading prompt token must be masked and the final completion token must keep
        its input id as its label.
        """
        labels = batch["labels"][0]
        input_ids = batch["input_ids"][0]
        self.assertFalse(
            torch.all(labels == self.IGNORE_INDEX),
            "Every label is masked; the response template was not found.",
        )
        # The first token belongs to the prompt, which precedes the response template.
        self.assertEqual(labels[0].item(), self.IGNORE_INDEX)
        # The last token belongs to the completion, which must still be trained on.
        self.assertEqual(labels[-1].item(), input_ids[-1].item())

    def test_data_collator_finds_response_template_llama2_tokenizer(self):
        # NOTE(review): despite the name, this uses a small GPT-2 test tokenizer, presumably as a
        # lightweight stand-in for the Llama-2 tokenizer — confirm.
        instruction = (
            "### System: You are a helpful assistant.\n\n\n\n\n\n"
            "### User: How much is 2+2?\n\n\n\n\n\n"
            "### Assistant: 2+2 equals 4"
        )
        instruction_template = "\n### User:"
        response_template = "\n### Assistant:"

        # Encoding a template on its own can yield different leading ids than encoding the same
        # text inside the full conversation. Dropping the first two ids presumably leaves a
        # suffix matching the in-context ids for this tokenizer (TODO confirm if it changes).
        tokenized_instruction_w_context = self.tokenizer.encode(instruction_template, add_special_tokens=False)[2:]
        tokenized_response_w_context = self.tokenizer.encode(response_template, add_special_tokens=False)[2:]

        # Plain check on the raw text before involving the tokenizer.
        self.assertIn(response_template, instruction)
        tokenized_instruction = self.tokenizer.encode(instruction, add_special_tokens=False)

        # Response template passed as already-tokenized (context-matched) ids.
        collator = DataCollatorForCompletionOnlyLM(tokenized_response_w_context, tokenizer=self.tokenizer)
        batch = collator.torch_call([tokenized_instruction])
        self._assert_prompt_masked_and_completion_labeled(batch)

        # Response and instruction templates both passed as already-tokenized ids.
        collator = DataCollatorForCompletionOnlyLM(
            tokenized_response_w_context, tokenized_instruction_w_context, tokenizer=self.tokenizer
        )
        batch = collator.torch_call([tokenized_instruction])
        self._assert_prompt_masked_and_completion_labeled(batch)

    def test_data_collator_handling_of_long_sequences(self):
        # The instruction never contains the response template, so every label must be masked.
        instruction = (
            "### System: You are a helpful assistant.\n\n\n\n\n\n"
            "### User: How much is 2+2? I'm asking because I'm not sure. "
            "And I'm not sure because I'm not good at math.\n\n\n"
        )
        response_template = "\n### Assistant:"
        tokenized_instruction = self.tokenizer.encode(instruction, add_special_tokens=False)

        # Response template only.
        collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=self.tokenizer)
        encoded_instance = collator.torch_call([tokenized_instruction])
        result = torch.all(encoded_instance["labels"] == self.IGNORE_INDEX)
        self.assertTrue(result, "Not all values in the tensor are -100.")

        # Response and instruction templates: the instruction template is present but the
        # response template is still missing, so the whole example must remain masked.
        instruction_template = "\n### User:"
        collator = DataCollatorForCompletionOnlyLM(response_template, instruction_template, tokenizer=self.tokenizer)
        encoded_instance = collator.torch_call([tokenized_instruction])
        result = torch.all(encoded_instance["labels"] == self.IGNORE_INDEX)
        self.assertTrue(result, "Not all values in the tensor are -100.")
|
|
|