|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| import gc
|
| import json
|
| import pathlib
|
| from unittest.mock import MagicMock, patch
|
|
|
| import pytest
|
| import torch
|
| import transformers
|
| from accelerate.utils.memory import release_memory
|
| from datasets import load_dataset
|
| from packaging.version import Version
|
| from packaging.version import parse as parse_version
|
| from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments
|
| from transformers.testing_utils import backend_empty_cache, torch_device
|
| from transformers.utils import is_peft_available
|
|
|
| from trl import SFTConfig, SFTTrainer
|
| from trl.trainer.sft_trainer import DataCollatorForLanguageModeling, dft_loss
|
|
|
| from .testing_utils import (
|
| TrlTestCase,
|
| ignore_warnings,
|
| require_ampere_or_newer,
|
| require_bitsandbytes,
|
| require_kernels,
|
| require_liger_kernel,
|
| require_peft,
|
| require_torch_accelerator,
|
| require_torch_multi_accelerator,
|
| require_vision,
|
| )
|
|
|
|
|
| if is_peft_available():
|
| import peft
|
| from peft import (
|
| LoraConfig,
|
| PeftModel,
|
| PrefixTuningConfig,
|
| PromptEncoderConfig,
|
| PromptTuningConfig,
|
| TaskType,
|
| get_peft_model,
|
| )
|
|
|
|
|
class TestDFTLoss(TrlTestCase):
    """Unit test for the standalone `dft_loss` function."""

    def test_dft_loss(self):
        """Check `dft_loss` against a cross-entropy reference on constant logits.

        All logits share a single value, so the softmax over the 2-token vocabulary is uniform. The test expects
        the DFT loss to equal half of the mean cross-entropy (presumably because each token's loss is weighted by
        its predicted probability, 0.5 here -- confirm against the `dft_loss` implementation).
        """
        n_rows, n_positions, n_vocab = 2, 3, 2

        # One random scalar broadcast to every logit; its actual value does not affect the softmax.
        shared_logit = torch.rand(1).item()
        logits = torch.full((n_rows, n_positions, n_vocab), shared_logit)

        # `dft_loss` only needs an object exposing `.logits`, so a mock stands in for the model output.
        model_output = MagicMock()
        model_output.logits = logits

        labels = torch.tensor([[1, 0, 0], [0, 1, -100]])
        reference_ce = torch.nn.functional.cross_entropy(
            logits.view(-1, n_vocab),
            labels.view(-1),
            ignore_index=-100,
            reduction="mean",
        )

        num_items_in_batch = 3
        dft_value = dft_loss(model_output, labels, num_items_in_batch)

        torch.testing.assert_close(reference_ce / 2.0, dft_value, atol=1e-4, rtol=1e-4)
|
|
|
|
|
class TestDataCollatorForLanguageModeling(TrlTestCase):
    """Tests for `DataCollatorForLanguageModeling`: padding, loss masking, padding-free mode and truncation."""

    @staticmethod
    def _assert_batch(batch, expected, check_keys=True):
        """Compare a collated batch with its expected content.

        Args:
            batch: Mapping returned by the collator.
            expected: Mapping from output key to the expected values, written as nested lists. Entries are
                compared in insertion order.
            check_keys: When `True`, additionally require that `batch` has exactly the keys of `expected`.
        """
        if check_keys:
            assert set(batch.keys()) == set(expected)
        for key, values in expected.items():
            torch.testing.assert_close(batch[key], torch.tensor(values))

    def test_basic_padding(self):
        """Test basic padding functionality without completion masks."""
        collator = DataCollatorForLanguageModeling(pad_token_id=0)
        batch = collator([{"input_ids": [1, 2, 3]}, {"input_ids": [4, 5]}])

        self._assert_batch(
            batch,
            {
                "input_ids": [[1, 2, 3], [4, 5, 0]],
                "attention_mask": [[1, 1, 1], [1, 1, 0]],
                "labels": [[1, 2, 3], [4, 5, -100]],
            },
        )

    def test_completion_mask(self):
        """Test that tokens with a zero completion mask get the label -100."""
        collator = DataCollatorForLanguageModeling(pad_token_id=0)
        batch = collator(
            [
                {"input_ids": [1, 2, 3], "completion_mask": [0, 1, 1]},
                {"input_ids": [4, 5], "completion_mask": [0, 1]},
            ]
        )

        self._assert_batch(
            batch,
            {
                "input_ids": [[1, 2, 3], [4, 5, 0]],
                "attention_mask": [[1, 1, 1], [1, 1, 0]],
                "labels": [[-100, 2, 3], [-100, 5, -100]],
            },
        )

    def test_completion_only_loss_disabled(self):
        """Test that the completion mask is ignored when completion_only_loss is disabled."""
        collator = DataCollatorForLanguageModeling(pad_token_id=0, completion_only_loss=False)
        batch = collator(
            [
                {"input_ids": [1, 2, 3], "completion_mask": [0, 1, 1]},
                {"input_ids": [4, 5], "completion_mask": [0, 1]},
            ]
        )

        # Only padding is masked in the labels; the completion mask has no effect.
        self._assert_batch(
            batch,
            {
                "input_ids": [[1, 2, 3], [4, 5, 0]],
                "attention_mask": [[1, 1, 1], [1, 1, 0]],
                "labels": [[1, 2, 3], [4, 5, -100]],
            },
        )

    def test_padding_free_mode(self):
        """Test padding-free mode where sequences are concatenated."""
        collator = DataCollatorForLanguageModeling(pad_token_id=0, padding_free=True)
        batch = collator([{"input_ids": [1, 2, 3]}, {"input_ids": [4, 5]}])

        # The first token of each concatenated sequence is expected to be masked in the labels.
        self._assert_batch(
            batch,
            {
                "input_ids": [[1, 2, 3, 4, 5]],
                "position_ids": [[0, 1, 2, 0, 1]],
                "labels": [[-100, 2, 3, -100, 5]],
            },
        )

    def test_padding_free_with_completion_mask(self):
        """Test padding-free mode with completion masks."""
        collator = DataCollatorForLanguageModeling(pad_token_id=0, padding_free=True)
        batch = collator(
            [
                {"input_ids": [1, 2, 3], "completion_mask": [0, 0, 1]},
                {"input_ids": [4, 5], "completion_mask": [1, 1]},
            ]
        )

        self._assert_batch(
            batch,
            {
                "input_ids": [[1, 2, 3, 4, 5]],
                "position_ids": [[0, 1, 2, 0, 1]],
                "labels": [[-100, -100, 3, -100, 5]],
            },
        )

    def test_packing(self):
        """Test that when using packing with position_ids, attention_mask is dropped with fa2."""
        collator = DataCollatorForLanguageModeling(pad_token_id=0, padding_free=True)

        # Each example is already packed; `seq_lengths` lists the lengths of the documents it contains.
        batch = collator(
            [
                {"input_ids": [1, 2, 3, 4, 5, 6], "seq_lengths": [3, 3]},
                {"input_ids": [7, 8, 9, 10, 11], "seq_lengths": [4, 1]},
            ]
        )

        self._assert_batch(
            batch,
            {
                "input_ids": [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]],
                "position_ids": [[0, 1, 2, 0, 1, 2, 0, 1, 2, 3, 0]],
                "labels": [[-100, 2, 3, -100, 5, 6, -100, 8, 9, 10, -100]],
            },
        )

    def test_pad_to_multiple_of(self):
        """Test padding to multiple of specified value."""
        collator = DataCollatorForLanguageModeling(pad_token_id=0, pad_to_multiple_of=4)
        batch = collator([{"input_ids": [1, 2, 3]}, {"input_ids": [4, 5]}])

        self._assert_batch(
            batch,
            {
                "input_ids": [[1, 2, 3, 0], [4, 5, 0, 0]],
                "attention_mask": [[1, 1, 1, 0], [1, 1, 0, 0]],
                "labels": [[1, 2, 3, -100], [4, 5, -100, -100]],
            },
        )

    def test_pad_to_multiple_of_and_padding_free(self):
        """Test that the concatenated padding-free sequence is padded to a multiple of the specified value."""
        collator = DataCollatorForLanguageModeling(pad_token_id=0, padding_free=True, pad_to_multiple_of=4)
        batch = collator([{"input_ids": [1, 2, 3]}, {"input_ids": [4, 5]}])

        self._assert_batch(
            batch,
            {
                "input_ids": [[1, 2, 3, 4, 5, 0, 0, 0]],
                "position_ids": [[0, 1, 2, 0, 1, 0, 0, 0]],
                "labels": [[-100, 2, 3, -100, 5, -100, -100, -100]],
            },
        )

    def test_custom_position_ids_but_no_padding_free(self):
        """Test that `seq_lengths` (and thus custom position ids) are ignored if padding_free is False."""
        collator = DataCollatorForLanguageModeling(pad_token_id=0)
        batch = collator(
            [
                {"input_ids": [1, 2, 3], "seq_lengths": [1, 2]},
                {"input_ids": [4, 5], "seq_lengths": [2]},
            ]
        )

        # Same output as plain padding: no `position_ids` key and no extra label masking.
        self._assert_batch(
            batch,
            {
                "input_ids": [[1, 2, 3], [4, 5, 0]],
                "attention_mask": [[1, 1, 1], [1, 1, 0]],
                "labels": [[1, 2, 3], [4, 5, -100]],
            },
        )

    def test_single_example(self):
        """Test collator with a single example."""
        collator = DataCollatorForLanguageModeling(pad_token_id=0)
        batch = collator([{"input_ids": [1, 2, 3, 4]}])

        self._assert_batch(
            batch,
            {
                "input_ids": [[1, 2, 3, 4]],
                "attention_mask": [[1, 1, 1, 1]],
                "labels": [[1, 2, 3, 4]],
            },
        )

    def test_different_pad_token_id(self):
        """Test with different pad token ID."""
        collator = DataCollatorForLanguageModeling(pad_token_id=999)
        batch = collator([{"input_ids": [1, 2, 3]}, {"input_ids": [4, 5]}])

        self._assert_batch(
            batch,
            {
                "input_ids": [[1, 2, 3], [4, 5, 999]],
                "attention_mask": [[1, 1, 1], [1, 1, 0]],
                "labels": [[1, 2, 3], [4, 5, -100]],
            },
        )

    def test_assistant_masks(self):
        """Test handling of assistant masks in examples."""
        collator = DataCollatorForLanguageModeling(pad_token_id=0)
        batch = collator(
            [
                {"input_ids": [1, 2, 3], "assistant_masks": [0, 1, 1]},
                {"input_ids": [4, 5], "assistant_masks": [0, 1]},
            ]
        )

        # This test has never pinned the exact key set, so only the tensor contents are compared.
        self._assert_batch(
            batch,
            {
                "input_ids": [[1, 2, 3], [4, 5, 0]],
                "attention_mask": [[1, 1, 1], [1, 1, 0]],
                "labels": [[-100, 2, 3], [-100, 5, -100]],
            },
            check_keys=False,
        )

    def test_max_length_keep_start(self):
        """Test that, without an explicit truncation_mode, over-long sequences keep their first max_length tokens."""
        collator = DataCollatorForLanguageModeling(pad_token_id=0, max_length=3)
        batch = collator([{"input_ids": [1, 2, 3, 4, 5]}, {"input_ids": [6, 7, 8]}])

        self._assert_batch(
            batch,
            {
                "input_ids": [[1, 2, 3], [6, 7, 8]],
                "attention_mask": [[1, 1, 1], [1, 1, 1]],
                "labels": [[1, 2, 3], [6, 7, 8]],
            },
        )

    def test_max_length_keep_end(self):
        """Test that truncation_mode="keep_end" keeps the last max_length tokens of over-long sequences."""
        collator = DataCollatorForLanguageModeling(pad_token_id=0, max_length=3, truncation_mode="keep_end")
        batch = collator([{"input_ids": [1, 2, 3, 4, 5]}, {"input_ids": [6, 7, 8]}])

        self._assert_batch(
            batch,
            {
                "input_ids": [[3, 4, 5], [6, 7, 8]],
                "attention_mask": [[1, 1, 1], [1, 1, 1]],
                "labels": [[3, 4, 5], [6, 7, 8]],
            },
        )

    def test_max_length_no_truncation_needed(self):
        """Test that max_length larger than sequences does not alter the output."""
        collator = DataCollatorForLanguageModeling(pad_token_id=0, max_length=10)
        batch = collator([{"input_ids": [1, 2, 3]}, {"input_ids": [4, 5]}])

        self._assert_batch(
            batch,
            {
                "input_ids": [[1, 2, 3], [4, 5, 0]],
                "attention_mask": [[1, 1, 1], [1, 1, 0]],
                "labels": [[1, 2, 3], [4, 5, -100]],
            },
        )

    def test_max_length_with_completion_mask(self):
        """Test that truncation is applied correctly when completion masks are present."""
        collator = DataCollatorForLanguageModeling(pad_token_id=0, max_length=3)
        batch = collator(
            [
                {"input_ids": [1, 2, 3, 4, 5], "completion_mask": [0, 0, 1, 1, 1]},
                {"input_ids": [6, 7, 8], "completion_mask": [0, 1, 1]},
            ]
        )

        self._assert_batch(
            batch,
            {
                "input_ids": [[1, 2, 3], [6, 7, 8]],
                "attention_mask": [[1, 1, 1], [1, 1, 1]],
                "labels": [[-100, -100, 3], [-100, 7, 8]],
            },
        )

    def test_max_length_keep_end_with_completion_mask(self):
        """Test keep_end truncation with completion masks preserves the final tokens."""
        collator = DataCollatorForLanguageModeling(pad_token_id=0, max_length=3, truncation_mode="keep_end")
        batch = collator(
            [
                {"input_ids": [1, 2, 3, 4, 5], "completion_mask": [0, 0, 1, 1, 1]},
                {"input_ids": [6, 7, 8], "completion_mask": [0, 1, 1]},
            ]
        )

        self._assert_batch(
            batch,
            {
                "input_ids": [[3, 4, 5], [6, 7, 8]],
                "attention_mask": [[1, 1, 1], [1, 1, 1]],
                "labels": [[3, 4, 5], [-100, 7, 8]],
            },
        )

    def test_max_length_invalid_truncation_mode(self):
        """Test that an invalid truncation_mode raises ValueError."""
        collator = DataCollatorForLanguageModeling(pad_token_id=0, max_length=3, truncation_mode="invalid")
        examples = [{"input_ids": [1, 2, 3, 4, 5]}]

        # The error is expected when the collator is called, not when it is constructed.
        with pytest.raises(ValueError, match="Unsupported truncation mode"):
            collator(examples)

    def test_single_example_single_doc(self):
        """Test position ids for one example holding a single document of length 5."""
        position_ids = DataCollatorForLanguageModeling.get_position_ids_from_packed_seq_lengths([[5]])

        assert len(position_ids) == 1
        assert torch.equal(position_ids[0], torch.arange(5))

    def test_single_example_multiple_docs(self):
        """Test that position ids restart at 0 for each document packed into one example."""
        position_ids = DataCollatorForLanguageModeling.get_position_ids_from_packed_seq_lengths([[3, 2]])

        assert len(position_ids) == 1
        assert torch.equal(position_ids[0], torch.tensor([0, 1, 2, 0, 1]))

    def test_multiple_examples(self):
        """Test that one position-id tensor is produced per example in the batch."""
        position_ids = DataCollatorForLanguageModeling.get_position_ids_from_packed_seq_lengths([[2, 2], [3]])

        assert len(position_ids) == 2
        assert torch.equal(position_ids[0], torch.tensor([0, 1, 0, 1]))
        assert torch.equal(position_ids[1], torch.arange(3))
|
|
|
|
|
| class TestSFTTrainer(TrlTestCase):
|
def test_init_with_training_arguments(self):
    """Check that `SFTTrainer` can be constructed from a plain `TrainingArguments` instead of an `SFTConfig`."""
    train_data = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")
    plain_args = TrainingArguments(output_dir=self.tmp_dir, report_to="none")

    # Construction alone is the test: no training is run and nothing is asserted on the result.
    SFTTrainer(
        model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
        args=plain_args,
        train_dataset=train_data,
    )
|
|
|
@pytest.mark.parametrize(
    "model_id",
    [
        "trl-internal-testing/tiny-Cohere2ForCausalLM",
        pytest.param(
            "trl-internal-testing/tiny-Glm4MoeForCausalLM",
            marks=pytest.mark.skipif(
                Version(transformers.__version__) < Version("5.0.0"),
                reason="GLM4 tokenizer requires transformers>=5.0.0",
            ),
        ),
        "trl-internal-testing/tiny-GptOssForCausalLM",
        "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
        "trl-internal-testing/tiny-Qwen3MoeForCausalLM",
    ],
)
def test_train(self, model_id):
    """Train each parametrized tiny model and check that a loss is logged and every parameter changes."""
    # Dataset
    train_data = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")

    # Trainer, built from the model id string
    config = SFTConfig(output_dir=self.tmp_dir, report_to="none")
    trainer = SFTTrainer(model=model_id, args=config, train_dataset=train_data)

    # Snapshot of every parameter before training
    weights_before = {n: weight.clone() for n, weight in trainer.model.named_parameters()}

    trainer.train()

    # A training loss must have been logged
    final_log = trainer.state.log_history[-1]
    assert final_log["train_loss"] is not None

    # Every parameter must differ from its snapshot
    for n, old_weight in weights_before.items():
        current_weight = trainer.model.get_parameter(n)
        assert not torch.allclose(old_weight, current_weight), f"Parameter {n} has not changed"
|
|
|
|
|
def test_train_gpt_oss(self):
    """Train the tiny GPT-OSS model on the `harmony` language-modeling dataset and check every parameter changes."""
    # Dataset
    train_data = load_dataset("trl-internal-testing/harmony", "language_modeling", split="train")

    # Trainer
    config = SFTConfig(output_dir=self.tmp_dir, report_to="none")
    trainer = SFTTrainer(
        model="trl-internal-testing/tiny-GptOssForCausalLM",
        args=config,
        train_dataset=train_data,
    )

    # Snapshot of every parameter before training
    weights_before = {n: weight.clone() for n, weight in trainer.model.named_parameters()}

    trainer.train()

    # A training loss must have been logged
    final_log = trainer.state.log_history[-1]
    assert final_log["train_loss"] is not None

    # Every parameter must differ from its snapshot
    for n, old_weight in weights_before.items():
        current_weight = trainer.model.get_parameter(n)
        assert not torch.allclose(old_weight, current_weight), f"Parameter {n} has not changed"
|
|
|
def test_train_model(self):
    """Train with an already-instantiated model object (rather than a model id) and check every parameter changes."""
    # Model, loaded up front in float32
    loaded_model = AutoModelForCausalLM.from_pretrained(
        "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
        dtype="float32",
    )

    # Dataset
    train_data = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")

    # Trainer
    config = SFTConfig(output_dir=self.tmp_dir, report_to="none")
    trainer = SFTTrainer(model=loaded_model, args=config, train_dataset=train_data)

    # Snapshot of every parameter before training
    weights_before = {n: weight.clone() for n, weight in trainer.model.named_parameters()}

    trainer.train()

    # A training loss must have been logged
    final_log = trainer.state.log_history[-1]
    assert final_log["train_loss"] is not None

    # Every parameter must differ from its snapshot
    for n, old_weight in weights_before.items():
        current_weight = trainer.model.get_parameter(n)
        assert not torch.allclose(old_weight, current_weight), f"Parameter {n} has not changed"
|
|
|
def test_train_dft_loss(self):
    """Train with `loss_type="dft"` and periodic evaluation, then check every parameter changes."""
    # Dataset: no split is requested, so both the "train" and "test" splits are available below
    splits = load_dataset("trl-internal-testing/zen", "standard_language_modeling")

    # Trainer, evaluating every 3 steps
    config = SFTConfig(
        output_dir=self.tmp_dir,
        loss_type="dft",
        learning_rate=0.1,
        report_to="none",
        eval_strategy="steps",
        eval_steps=3,
    )
    trainer = SFTTrainer(
        model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
        args=config,
        train_dataset=splits["train"],
        eval_dataset=splits["test"],
    )

    # Snapshot of every parameter before training
    weights_before = {n: weight.clone() for n, weight in trainer.model.named_parameters()}

    trainer.train()

    # A training loss must have been logged
    final_log = trainer.state.log_history[-1]
    assert final_log["train_loss"] is not None

    # Every parameter must differ from its snapshot
    for n, old_weight in weights_before.items():
        current_weight = trainer.model.get_parameter(n)
        assert not torch.allclose(old_weight, current_weight), f"Parameter {n} has not changed"
|
|
|
def test_train_moe_model_with_aux_loss(self):
    """Train a tiny MoE model with router logits enabled and check that an `aux_loss` entry is logged."""
    # Dataset
    train_data = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")

    # Trainer: the model is loaded with `output_router_logits=True`
    config = SFTConfig(
        output_dir=self.tmp_dir,
        report_to="none",
        model_init_kwargs={"output_router_logits": True},
    )
    trainer = SFTTrainer(
        model="trl-internal-testing/tiny-Qwen3MoeForCausalLM",
        args=config,
        train_dataset=train_data,
    )

    # Snapshot of every parameter before training
    weights_before = {n: weight.clone() for n, weight in trainer.model.named_parameters()}

    trainer.train()

    # Both the training loss and the auxiliary loss must appear in the last log entry
    final_log = trainer.state.log_history[-1]
    assert final_log["train_loss"] is not None
    assert final_log["aux_loss"] is not None

    # Every parameter must differ from its snapshot
    for n, old_weight in weights_before.items():
        current_weight = trainer.model.get_parameter(n)
        assert not torch.allclose(old_weight, current_weight), f"Parameter {n} has not changed"
|
|
|
def test_train_with_formatting_func(self):
    """Train on a preference dataset that a `formatting_func` turns into plain text, and check parameters change."""

    def formatting_prompts_func(example):
        # Render one preference example as a single text block.
        chosen = example["chosen"]
        rejected = example["rejected"]
        return f"### Chosen: {chosen}\n### Rejected: {rejected}"

    # Dataset
    train_data = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train")

    # Trainer
    config = SFTConfig(output_dir=self.tmp_dir, report_to="none")
    trainer = SFTTrainer(
        model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
        args=config,
        train_dataset=train_data,
        formatting_func=formatting_prompts_func,
    )

    # Snapshot of every parameter before training
    weights_before = {n: weight.clone() for n, weight in trainer.model.named_parameters()}

    trainer.train()

    # A training loss must have been logged
    final_log = trainer.state.log_history[-1]
    assert final_log["train_loss"] is not None

    # Every parameter must differ from its snapshot
    for n, old_weight in weights_before.items():
        current_weight = trainer.model.get_parameter(n)
        assert not torch.allclose(old_weight, current_weight), f"Parameter {n} has not changed"
|
|
|
def test_train_model_dtype(self):
    """Train with the model loaded in float16 and check parameters stay float16 and change."""
    # Dataset
    train_data = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")

    # Trainer: the dtype is requested through `model_init_kwargs`
    config = SFTConfig(
        output_dir=self.tmp_dir,
        model_init_kwargs={"dtype": torch.float16},
        learning_rate=0.1,
        report_to="none",
    )
    trainer = SFTTrainer(
        model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
        args=config,
        train_dataset=train_data,
    )

    # Snapshot of every parameter before training
    weights_before = {n: weight.clone() for n, weight in trainer.model.named_parameters()}

    trainer.train()

    # A training loss must have been logged
    final_log = trainer.state.log_history[-1]
    assert final_log["train_loss"] is not None

    for n, old_weight in weights_before.items():
        # Parameters whose name contains "layernorm" are excluded from both checks below.
        if "layernorm" not in n:
            current_weight = trainer.model.get_parameter(n)
            # The dtype must be preserved through training
            assert current_weight.dtype == torch.float16
            assert not torch.allclose(old_weight, current_weight), f"Parameter {n} has not changed"
|
|
|
@require_peft
def test_train_dense_with_peft_config_lora(self):
    """Train a dense model with a default `LoraConfig` passed as `peft_config`.

    Base-model parameters must be unchanged after training, while the remaining parameters (except those whose
    name contains "base_layer") must have changed.
    """
    # A reference copy of the model is loaded only to collect the base parameter names
    model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
    model = AutoModelForCausalLM.from_pretrained(model_id, dtype="float32")
    # Set for O(1) membership tests in the loop below
    base_param_names = {f"base_model.model.{n}" for n, _ in model.named_parameters()}

    # Dataset
    dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")

    # Trainer
    training_args = SFTConfig(output_dir=self.tmp_dir, report_to="none")
    trainer = SFTTrainer(
        model=model_id,
        args=training_args,
        train_dataset=dataset,
        peft_config=LoraConfig(),
    )

    # Snapshot of every parameter before training
    previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

    trainer.train()

    # A training loss must have been logged
    assert trainer.state.log_history[-1]["train_loss"] is not None

    for n, param in previous_trainable_params.items():
        new_param = trainer.model.get_parameter(n)
        if n in base_param_names:
            # Pass the message through `msg=`; a trailing `, "..."` after the call would be silently discarded.
            torch.testing.assert_close(
                param, new_param, msg=lambda details, name=n: f"Parameter {name} has changed\n{details}"
            )
        elif "base_layer" not in n:
            assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"
|
|
|
@pytest.mark.parametrize(
    "peft_type",
    [
        "prompt_tuning",
        "prefix_tuning",
        "prompt_encoder",
    ],
)
@require_peft
def test_train_with_peft_config_prompt_tuning(self, peft_type):
    """Train with a prompt-learning PEFT config (prompt tuning, prefix tuning or prompt encoder).

    Base-model parameters must be unchanged after training; every other parameter must have changed.

    Raises:
        ValueError: If `peft_type` is not one of the parametrized values.
    """
    # A reference copy of the model is loaded to collect the base parameter names and read the hidden size
    model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
    model = AutoModelForCausalLM.from_pretrained(model_id, dtype="float32")
    # Set for O(1) membership tests in the loop below
    base_param_names = {f"base_model.{n}" for n, _ in model.named_parameters()}

    # Dataset
    dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")

    # Trainer
    training_args = SFTConfig(bf16=False, output_dir=self.tmp_dir, report_to="none", gradient_checkpointing=False)
    if peft_type == "prompt_tuning":
        peft_config = PromptTuningConfig(
            task_type=TaskType.CAUSAL_LM,
            num_virtual_tokens=4,
            tokenizer_name_or_path=model_id,
        )
    elif peft_type == "prefix_tuning":
        if parse_version(peft.__version__) <= Version("0.17.1"):
            pytest.xfail(
                "Prefix tuning with device_map='auto' is broken in peft 0.17.1 and below. See "
                "https://github.com/huggingface/peft/issues/2821"
            )
        peft_config = PrefixTuningConfig(
            task_type=TaskType.CAUSAL_LM,
            num_virtual_tokens=4,
        )
    elif peft_type == "prompt_encoder":
        peft_config = PromptEncoderConfig(
            task_type=TaskType.CAUSAL_LM,
            num_virtual_tokens=4,
            encoder_hidden_size=model.config.hidden_size,
        )
    else:
        # Fail loudly rather than hitting an unbound `peft_config` below
        raise ValueError(f"Unexpected peft_type: {peft_type!r}")
    trainer = SFTTrainer(
        model=model_id,
        args=training_args,
        train_dataset=dataset,
        peft_config=peft_config,
    )

    # Snapshot of every parameter before training
    previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

    trainer.train()

    # A training loss must have been logged
    assert trainer.state.log_history[-1]["train_loss"] is not None

    for n, param in previous_trainable_params.items():
        new_param = trainer.model.get_parameter(n)
        if n in base_param_names:
            # Pass the message through `msg=`; a trailing `, "..."` after the call would be silently discarded.
            torch.testing.assert_close(
                param, new_param, msg=lambda details, name=n: f"Parameter {name} has changed\n{details}"
            )
        else:
            assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"
|
|
|
@require_peft
def test_train_moe_with_peft_config(self):
    """Train a MoE model with LoRA applied to the expert projections through `target_parameters`.

    Base-model parameters must be unchanged after training, while the remaining parameters (except those whose
    name contains "base_layer") must have changed.
    """
    # A reference copy of the model is loaded only to collect the base parameter names
    model_id = "trl-internal-testing/tiny-GptOssForCausalLM"
    model = AutoModelForCausalLM.from_pretrained(model_id, dtype="float32")
    # Set for O(1) membership tests in the loop below
    base_param_names = {f"base_model.model.{n}" for n, _ in model.named_parameters()}

    # Dataset
    dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")

    # Trainer
    training_args = SFTConfig(output_dir=self.tmp_dir, report_to="none")
    trainer = SFTTrainer(
        model=model_id,
        args=training_args,
        train_dataset=dataset,
        peft_config=LoraConfig(target_parameters=["mlp.experts.down_proj", "mlp.experts.gate_up_proj"]),
    )

    # Snapshot of every parameter before training
    previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

    trainer.train()

    # A training loss must have been logged
    assert trainer.state.log_history[-1]["train_loss"] is not None

    for n, param in previous_trainable_params.items():
        new_param = trainer.model.get_parameter(n)
        if n in base_param_names:
            # Pass the message through `msg=`; a trailing `, "..."` after the call would be silently discarded.
            torch.testing.assert_close(
                param, new_param, msg=lambda details, name=n: f"Parameter {name} has changed\n{details}"
            )
        elif "base_layer" not in n:
            assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"
|
|
|
@require_peft
def test_train_peft_model(self):
    """Train a model that was wrapped with `get_peft_model` before being handed to the trainer.

    Base-model parameters must be unchanged after training, while the remaining parameters (except those whose
    name contains "base_layer") must have changed.
    """
    # Model
    model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
    model = AutoModelForCausalLM.from_pretrained(model_id, dtype="float32")

    # Collect the base parameter names before wrapping, prefixed the way the wrapped model is expected to
    # expose them. A set gives O(1) membership tests in the loop below.
    base_param_names = {f"base_model.model.{n}" for n, _ in model.named_parameters()}

    # Wrap the model with a default LoRA adapter
    lora_config = LoraConfig()
    model = get_peft_model(model, lora_config)

    # Dataset
    dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")

    # Trainer
    training_args = SFTConfig(output_dir=self.tmp_dir, report_to="none")
    trainer = SFTTrainer(model=model, args=training_args, train_dataset=dataset)

    # Snapshot of every parameter before training
    previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

    trainer.train()

    # A training loss must have been logged
    assert trainer.state.log_history[-1]["train_loss"] is not None

    for n, param in previous_trainable_params.items():
        new_param = trainer.model.get_parameter(n)
        if n in base_param_names:
            # Pass the message through `msg=`; a trailing `, "..."` after the call would be silently discarded.
            torch.testing.assert_close(
                param, new_param, msg=lambda details, name=n: f"Parameter {name} has changed\n{details}"
            )
        elif "base_layer" not in n:
            assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"
|
|
|
|
|
|
|
|
|
@require_peft
def test_train_with_peft_config_and_gradient_checkpointing(self):
    """Train with a default `LoraConfig` and `gradient_checkpointing=True`.

    Base-model parameters must be unchanged after training, while the remaining parameters (except those whose
    name contains "base_layer") must have changed.
    """
    # A reference copy of the model is loaded only to collect the base parameter names
    model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
    model = AutoModelForCausalLM.from_pretrained(model_id, dtype="float32")
    # Set for O(1) membership tests in the loop below
    base_param_names = {f"base_model.model.{n}" for n, _ in model.named_parameters()}

    # Dataset
    dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")

    # Trainer
    training_args = SFTConfig(output_dir=self.tmp_dir, gradient_checkpointing=True, report_to="none")
    trainer = SFTTrainer(
        model=model_id,
        args=training_args,
        train_dataset=dataset,
        peft_config=LoraConfig(),
    )

    # Snapshot of every parameter before training
    previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

    trainer.train()

    # A training loss must have been logged
    assert trainer.state.log_history[-1]["train_loss"] is not None

    for n, param in previous_trainable_params.items():
        new_param = trainer.model.get_parameter(n)
        if n in base_param_names:
            # Pass the message through `msg=`; a trailing `, "..."` after the call would be silently discarded.
            torch.testing.assert_close(
                param, new_param, msg=lambda details, name=n: f"Parameter {name} has changed\n{details}"
            )
        elif "base_layer" not in n:
            assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"
|
|
|
@pytest.mark.parametrize("use_reentrant", [True, False])
@require_peft
def test_train_with_peft_config_and_gradient_checkpointing_reentrant(self, use_reentrant):
    """Train with a default `LoraConfig` and gradient checkpointing, for both values of `use_reentrant`.

    Base-model parameters must be unchanged after training, while the remaining parameters (except those whose
    name contains "base_layer") must have changed.
    """
    # A reference copy of the model is loaded only to collect the base parameter names
    model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
    model = AutoModelForCausalLM.from_pretrained(model_id, dtype="float32")
    # Set for O(1) membership tests in the loop below
    base_param_names = {f"base_model.model.{n}" for n, _ in model.named_parameters()}

    # Dataset
    dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")

    # Trainer
    training_args = SFTConfig(
        output_dir=self.tmp_dir,
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={"use_reentrant": use_reentrant},
        report_to="none",
    )
    trainer = SFTTrainer(
        model=model_id,
        args=training_args,
        train_dataset=dataset,
        peft_config=LoraConfig(),
    )

    # Snapshot of every parameter before training
    previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

    trainer.train()

    # A training loss must have been logged
    assert trainer.state.log_history[-1]["train_loss"] is not None

    for n, param in previous_trainable_params.items():
        new_param = trainer.model.get_parameter(n)
        if n in base_param_names:
            # Pass the message through `msg=`; a trailing `, "..."` after the call would be silently discarded.
            torch.testing.assert_close(
                param, new_param, msg=lambda details, name=n: f"Parameter {name} has changed\n{details}"
            )
        elif "base_layer" not in n:
            assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"
|
|
|
@require_liger_kernel
def test_train_with_liger(self):
    """Train with `use_liger_kernel=True` and check that a loss is logged and every parameter changes."""
    # Dataset
    train_data = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")

    # Trainer
    config = SFTConfig(output_dir=self.tmp_dir, use_liger_kernel=True, report_to="none")
    trainer = SFTTrainer(
        model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
        args=config,
        train_dataset=train_data,
    )

    # Snapshot of every parameter before training
    weights_before = {n: weight.clone() for n, weight in trainer.model.named_parameters()}

    trainer.train()

    # A training loss must have been logged
    final_log = trainer.state.log_history[-1]
    assert final_log["train_loss"] is not None

    # Every parameter must differ from its snapshot
    for n, old_weight in weights_before.items():
        current_weight = trainer.model.get_parameter(n)
        assert not torch.allclose(old_weight, current_weight), f"Parameter {n} has not changed"
|
|
|
@require_torch_accelerator
@require_liger_kernel
def test_compute_loss_skip_logits_on_eval_without_metrics_with_liger(self):
    """Check `compute_loss` forwards `skip_logits=True` in eval mode with Liger on and no `compute_metrics`."""
    train_data = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train[:1]")

    config = SFTConfig(
        output_dir=self.tmp_dir,
        use_liger_kernel=False,
        report_to="none",
        max_length=8,
        bf16=False,
    )
    trainer = SFTTrainer(
        model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
        args=config,
        train_dataset=train_data,
        compute_metrics=None,
    )
    # The trainer is built with the flag off; it is switched on afterwards, directly on the args.
    trainer.args.use_liger_kernel = True
    trainer.model.eval()

    seen = {}

    def mock_super_compute_loss(model, inputs, return_outputs=False, num_items_in_batch=None):
        # Record what the patched parent method receives, then hand back a dummy (loss, outputs) pair.
        seen["skip_logits"] = inputs.get("skip_logits")
        fake_loss = torch.tensor(1.0, requires_grad=True)
        fake_outputs = MagicMock()
        fake_outputs.token_accuracy = None
        fake_outputs.logits = torch.randn(1, 5, trainer.model.config.vocab_size)
        return (fake_loss, fake_outputs)

    batch = {
        "input_ids": torch.tensor([[1, 2, 3, 4, 5]]),
        "labels": torch.tensor([[1, 2, 3, 4, 5]]),
        "attention_mask": torch.tensor([[1, 1, 1, 1, 1]]),
    }

    with patch("transformers.Trainer.compute_loss", side_effect=mock_super_compute_loss):
        trainer.compute_loss(trainer.model, batch)

    assert seen["skip_logits"] is True
|
|
|
@require_torch_accelerator
@require_liger_kernel
def test_predict_does_not_skip_logits_with_liger(self):
    """Check that `predict` reaches the parent `compute_loss` with `skip_logits=False`.

    As in the sibling eval test, `use_liger_kernel` is only switched on after the trainer has been constructed.
    """
    single_example = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train[:1]")

    config = SFTConfig(
        output_dir=self.tmp_dir,
        use_liger_kernel=False,
        report_to="none",
        max_length=8,
        bf16=False,
    )
    sft_trainer = SFTTrainer(
        model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
        args=config,
        train_dataset=single_example,
        compute_metrics=None,
    )
    sft_trainer.args.use_liger_kernel = True
    sft_trainer.model.eval()

    seen = {}

    def fake_parent_compute_loss(model, inputs, return_outputs=False, num_items_in_batch=None):
        # Record the flag, then return a (loss, outputs) pair whose outputs are a plain tuple.
        seen["skip_logits"] = inputs.get("skip_logits")
        loss = torch.tensor(1.0, requires_grad=True)
        outputs = (loss, torch.randn(1, 5, sft_trainer.model.config.vocab_size))
        return (loss, outputs)

    with patch("transformers.Trainer.compute_loss", side_effect=fake_parent_compute_loss):
        sft_trainer.predict(sft_trainer.train_dataset)

    assert seen["skip_logits"] is False
|
|
|
def test_train_with_non_chatml_conversational_data(self):
    """Train on conversations stored under `conversations` with `from`/`value` keys.

    The `messages` column of the source dataset is rewritten so that each message's `role` becomes `from` and
    its `content` becomes `value`; training must still produce a loss and change every parameter.
    """
    dataset = load_dataset("trl-internal-testing/zen", "conversational_language_modeling", split="train")

    def rename_fields(example: dict) -> dict:
        # `example` is a single row, indexed by column name (it was previously mis-annotated as `list[dict]`).
        return {"conversations": [{"from": m["role"], "value": m["content"]} for m in example["messages"]]}

    dataset = dataset.map(rename_fields, remove_columns="messages")

    training_args = SFTConfig(output_dir=self.tmp_dir, report_to="none")
    trainer = SFTTrainer(
        model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", args=training_args, train_dataset=dataset
    )

    # Snapshot every parameter so it can be compared with its post-training value.
    previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

    trainer.train()

    # The last log entry must carry a training loss.
    assert trainer.state.log_history[-1]["train_loss"] is not None

    for n, param in previous_trainable_params.items():
        new_param = trainer.model.get_parameter(n)
        assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"
|
|
|
def test_train_with_pretokenized_data(self):
    """Train on a dataset whose `text` column was replaced by tokenizer output beforehand."""
    checkpoint = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
    tok = AutoTokenizer.from_pretrained(checkpoint)
    raw_data = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")

    def encode(example):
        text = example["text"]
        return tok(text)

    # Swap the raw text for the tokenizer's output columns.
    encoded_data = raw_data.map(encode, remove_columns=["text"])

    config = SFTConfig(output_dir=self.tmp_dir, report_to="none")
    sft_trainer = SFTTrainer(model=checkpoint, args=config, train_dataset=encoded_data)

    # Snapshot every parameter so it can be compared with its post-training value.
    weights_before = {name: tensor.clone() for name, tensor in sft_trainer.model.named_parameters()}

    sft_trainer.train()

    last_log = sft_trainer.state.log_history[-1]
    assert last_log["train_loss"] is not None

    for name, old_value in weights_before.items():
        new_value = sft_trainer.model.get_parameter(name)
        assert not torch.allclose(old_value, new_value), f"Parameter {name} has not changed"
|
|
|
def test_skip_prepare_dataset_passes_truncation_to_text_collator(self):
    """With `skip_prepare_dataset`, `max_length` and `truncation_mode` must land on the data collator.

    Building the config with `truncation_mode="keep_end"` is also expected to emit a `FutureWarning`.
    """
    two_examples = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train[:2]")

    with pytest.warns(FutureWarning, match="keep_end.*deprecated"):
        config = SFTConfig(
            output_dir=self.tmp_dir,
            max_length=16,
            truncation_mode="keep_end",
            dataset_kwargs={"skip_prepare_dataset": True},
            report_to="none",
        )

    sft_trainer = SFTTrainer(
        model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
        args=config,
        train_dataset=two_examples,
    )

    collator = sft_trainer.data_collator
    assert isinstance(collator, DataCollatorForLanguageModeling)
    assert collator.max_length == 16
    assert collator.truncation_mode == "keep_end"
|
|
|
def test_padding_free_without_packing_and_max_length_raises(self):
    """Constructing the trainer with `padding_free=True` and `max_length=16` must raise `ValueError`."""
    two_examples = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train[:2]")
    config = SFTConfig(
        output_dir=self.tmp_dir,
        max_length=16,
        padding_free=True,
        report_to="none",
    )
    trainer_kwargs = {
        "model": "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
        "args": config,
        "train_dataset": two_examples,
    }

    # The error is expected at construction time, before any training starts.
    with pytest.raises(ValueError, match="`max_length` is not enforced"):
        SFTTrainer(**trainer_kwargs)
|
|
|
def test_train_with_iterable_dataset(self):
    """Train for three steps on a dataset loaded with `streaming=True`."""
    stream = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train", streaming=True)

    config = SFTConfig(output_dir=self.tmp_dir, max_steps=3, report_to="none")
    sft_trainer = SFTTrainer(
        model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
        args=config,
        train_dataset=stream,
    )

    # Snapshot every parameter so it can be compared with its post-training value.
    weights_before = {name: tensor.clone() for name, tensor in sft_trainer.model.named_parameters()}

    sft_trainer.train()

    last_log = sft_trainer.state.log_history[-1]
    assert last_log["train_loss"] is not None

    for name, old_value in weights_before.items():
        new_value = sft_trainer.model.get_parameter(name)
        assert not torch.allclose(old_value, new_value), f"Parameter {name} has not changed"
|
|
|
@require_kernels
@require_ampere_or_newer
def test_train_padding_free(self):
    """Train with `padding_free=True`, bf16 and the `kernels-community/flash-attn2` attention implementation."""
    train_data = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")

    model_kwargs = {"attn_implementation": "kernels-community/flash-attn2"}
    config = SFTConfig(
        output_dir=self.tmp_dir,
        padding_free=True,
        model_init_kwargs=model_kwargs,
        bf16=True,
        report_to="none",
    )
    sft_trainer = SFTTrainer(
        model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
        args=config,
        train_dataset=train_data,
    )

    # Snapshot every parameter so it can be compared with its post-training value.
    weights_before = {name: tensor.clone() for name, tensor in sft_trainer.model.named_parameters()}

    sft_trainer.train()

    last_log = sft_trainer.state.log_history[-1]
    assert last_log["train_loss"] is not None

    for name, old_value in weights_before.items():
        new_value = sft_trainer.model.get_parameter(name)
        assert not torch.allclose(old_value, new_value), f"Parameter {name} has not changed"
|
|
|
@pytest.mark.parametrize("packing_strategy", ["bfd", "wrapped"])
@ignore_warnings(message="You are using packing, but the attention implementation is not.*", category=UserWarning)
@ignore_warnings(message="Padding-free training is enabled, but the attention.*", category=UserWarning)
def test_train_packing(self, packing_strategy):
    """Train with packing enabled (`max_length=10`) for each parametrized packing strategy."""
    train_data = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")

    config = SFTConfig(
        output_dir=self.tmp_dir,
        packing=True,
        packing_strategy=packing_strategy,
        max_length=10,
        report_to="none",
    )
    sft_trainer = SFTTrainer(
        model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
        args=config,
        train_dataset=train_data,
    )

    # Snapshot every parameter so it can be compared with its post-training value.
    weights_before = {name: tensor.clone() for name, tensor in sft_trainer.model.named_parameters()}

    sft_trainer.train()

    last_log = sft_trainer.state.log_history[-1]
    assert last_log["train_loss"] is not None

    for name, old_value in weights_before.items():
        new_value = sft_trainer.model.get_parameter(name)
        assert not torch.allclose(old_value, new_value), f"Parameter {name} has not changed"
|
|
|
@ignore_warnings(message="You are using packing, but the attention implementation is not.*", category=UserWarning)
@ignore_warnings(message="Padding-free training is enabled, but the attention.*", category=UserWarning)
def test_eval_packing(self):
    """Check the packed train and eval datasets produced with `packing=True` and `max_length=64`.

    No training is run; only the prepared datasets on the trainer are inspected.
    """
    splits = load_dataset("trl-internal-testing/zen", "standard_language_modeling")

    config = SFTConfig(
        output_dir=self.tmp_dir,
        packing=True,
        max_length=64,
        report_to="none",
    )
    sft_trainer = SFTTrainer(
        model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
        args=config,
        train_dataset=splits["train"],
        eval_dataset=splits["test"],
    )
    packed_train = sft_trainer.train_dataset
    packed_eval = sft_trainer.eval_dataset

    # Total number of original sequences recorded across all packed rows.
    train_sequence_count = sum(map(len, packed_train["seq_lengths"]))
    eval_sequence_count = sum(map(len, packed_eval["seq_lengths"]))
    assert train_sequence_count == 17
    assert eval_sequence_count == 2

    # No packed row may hold more than 64 tokens in total.
    assert all(sum(lengths) <= 64 for lengths in packed_train["seq_lengths"])
    assert all(sum(lengths) <= 64 for lengths in packed_eval["seq_lengths"])

    # Number of rows left after packing.
    assert len(packed_train["input_ids"]) == 3
    assert len(packed_eval["input_ids"]) == 1
|
|
|
@ignore_warnings(message="You are using packing, but the attention implementation is not.*", category=UserWarning)
@ignore_warnings(message="Padding-free training is enabled, but the attention.*", category=UserWarning)
def test_only_train_packing(self):
    """Check that `eval_packing=False` packs only the train dataset.

    The trainer is built with `packing=True`, `eval_packing=False` and `max_length=64`; no training is run,
    only the prepared datasets on the trainer are inspected.
    """
    dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling")

    training_args = SFTConfig(
        output_dir=self.tmp_dir,
        packing=True,
        eval_packing=False,
        max_length=64,
        report_to="none",
    )
    trainer = SFTTrainer(
        model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
        args=training_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["test"],
    )

    # Total number of original sequences recorded across all packed train rows.
    num_train_seqs = sum(len(x) for x in trainer.train_dataset["seq_lengths"])
    assert num_train_seqs == 17

    # The eval dataset must not have the packing column. Check the column names explicitly: `in` applied to
    # the dataset object itself would compare the string against rows rather than columns.
    assert "seq_lengths" not in trainer.eval_dataset.column_names

    # No packed train row may hold more than 64 tokens in total.
    assert all(sum(x) <= 64 for x in trainer.train_dataset["seq_lengths"])

    # Row counts of the train and eval datasets.
    assert len(trainer.train_dataset["input_ids"]) == 3
    assert len(trainer.eval_dataset["input_ids"]) == 2
|
|
|
def test_train_with_chat_template_kwargs(self):
    """Train with a per-example `chat_template_kwargs` column driving a custom chat template.

    The template upper-cases role names when `role_capital` is true. Examples alternate between
    `role_capital=False` and `role_capital=True`, and the first two tokenized examples are checked against the
    matching system prompt before training.
    """
    conversations = load_dataset("trl-internal-testing/zen", "conversational_language_modeling", split="train")

    config = SFTConfig(output_dir=self.tmp_dir, report_to="none")

    tok = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")

    # Chat template that reads the `role_capital` variable to decide the casing of role names.
    # fmt: off
    tok.chat_template = '{%- if messages[0]["role"] == "system" -%} {{ "<|im_start|>" + ("SYSTEM" if role_capital else "system") + "\\n" + messages[0]["content"] + "<|im_end|>\\n" }}{%- else -%} {{ "<|im_start|>" + ("SYSTEM" if role_capital else "system") + "\\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\\n" }}{%- endif -%}{%- for message in messages -%} {%- if (message.role == "user") or (message.role == "system" and not loop.first) or (message.role == "assistant" and not message.tool_calls) -%} {{ "<|im_start|>" + (message.role.upper() if role_capital else message.role) + "\\n" + message.content + "<|im_end|>\\n" }} {%- elif message.role == "assistant" -%} {{ "<|im_start|>" + ("ASSISTANT" if role_capital else "assistant") }} {%- if message.content -%} {{ "\\n" + message.content }} {%- endif -%} {{ "<|im_end|>\\n" }} {%- elif message.role == "tool" -%} {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != "tool") -%} {{ "<|im_start|>" + ("USER" if role_capital else "user") }} {%- endif -%} {{ "\\n<tool_response>\\n" + message.content + "\\n</tool_response>" }} {%- if loop.last or (messages[loop.index0 + 1].role != "tool") -%} {{ "<|im_end|>\\n" }} {%- endif -%} {%- endif -%}{%- endfor -%}{%- if add_generation_prompt -%} {{ "<|im_start|>" + ("ASSISTANT" if role_capital else "assistant") + "\\n" }}{%- endif -%}'
    # fmt: on

    # Even rows get role_capital=False, odd rows role_capital=True.
    per_example_kwargs = [{"role_capital": bool(i % 2)} for i in range(len(conversations))]
    conversations = conversations.add_column("chat_template_kwargs", per_example_kwargs)
    assert "chat_template_kwargs" in conversations.features

    sft_trainer = SFTTrainer(
        model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
        args=config,
        train_dataset=conversations,
        processing_class=tok,
    )

    # The trainer must keep the custom template.
    assert sft_trainer.processing_class.chat_template == tok.chat_template

    # Row 0 must start with the lower-case system prompt, row 1 with the upper-case one.
    for index, role in enumerate(("system", "SYSTEM")):
        system_prompt = (
            f"<|im_start|>{role}\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>"
        )
        expected_prefix = sft_trainer.processing_class(system_prompt)["input_ids"]
        actual_ids = sft_trainer.train_dataset[index]["input_ids"]
        assert actual_ids[: len(expected_prefix)] == expected_prefix

    # Snapshot every parameter so it can be compared with its post-training value.
    weights_before = {name: tensor.clone() for name, tensor in sft_trainer.model.named_parameters()}

    sft_trainer.train()

    last_log = sft_trainer.state.log_history[-1]
    assert last_log["train_loss"] is not None

    for name, old_value in weights_before.items():
        new_value = sft_trainer.model.get_parameter(name)
        assert not torch.allclose(old_value, new_value), f"Parameter {name} has not changed"
|
|
|
def test_train_assistant_only(self):
    """Train with `assistant_only_loss=True` on conversational data and check every parameter changes."""
    conversations = load_dataset("trl-internal-testing/zen", "conversational_language_modeling", split="train")

    config = SFTConfig(output_dir=self.tmp_dir, assistant_only_loss=True, report_to="none")
    sft_trainer = SFTTrainer(
        model="trl-internal-testing/tiny-Qwen3ForCausalLM",
        args=config,
        train_dataset=conversations,
    )

    # Snapshot every parameter so it can be compared with its post-training value.
    weights_before = {name: tensor.clone() for name, tensor in sft_trainer.model.named_parameters()}

    sft_trainer.train()

    last_log = sft_trainer.state.log_history[-1]
    assert last_log["train_loss"] is not None

    for name, old_value in weights_before.items():
        new_value = sft_trainer.model.get_parameter(name)
        assert not torch.allclose(old_value, new_value), f"Parameter {name} has not changed"
|
|
|
def test_train_completion_only(self):
    """Train with `completion_only_loss=True` on prompt-completion data and check every parameter changes."""
    prompt_completion = load_dataset(
        "trl-internal-testing/zen", "conversational_prompt_completion", split="train"
    )

    config = SFTConfig(output_dir=self.tmp_dir, completion_only_loss=True, report_to="none")
    sft_trainer = SFTTrainer(
        model="trl-internal-testing/tiny-Qwen3ForCausalLM",
        args=config,
        train_dataset=prompt_completion,
    )

    # Snapshot every parameter so it can be compared with its post-training value.
    weights_before = {name: tensor.clone() for name, tensor in sft_trainer.model.named_parameters()}

    sft_trainer.train()

    last_log = sft_trainer.state.log_history[-1]
    assert last_log["train_loss"] is not None

    for name, old_value in weights_before.items():
        new_value = sft_trainer.model.get_parameter(name)
        assert not torch.allclose(old_value, new_value), f"Parameter {name} has not changed"
|
|
|
def test_train_completion_only_harmony(self):
    """Train a tiny GptOss model with `completion_only_loss=True` on the harmony prompt-completion dataset."""
    harmony_data = load_dataset("trl-internal-testing/harmony", "prompt_completion", split="train")

    config = SFTConfig(output_dir=self.tmp_dir, completion_only_loss=True, report_to="none")
    sft_trainer = SFTTrainer(
        model="trl-internal-testing/tiny-GptOssForCausalLM",
        args=config,
        train_dataset=harmony_data,
    )

    # Snapshot every parameter so it can be compared with its post-training value.
    weights_before = {name: tensor.clone() for name, tensor in sft_trainer.model.named_parameters()}

    sft_trainer.train()

    last_log = sft_trainer.state.log_history[-1]
    assert last_log["train_loss"] is not None

    for name, old_value in weights_before.items():
        new_value = sft_trainer.model.get_parameter(name)
        assert not torch.allclose(old_value, new_value), f"Parameter {name} has not changed"
|
|
|
def test_train_assistant_only_and_completion_only(self):
    """Train with both `assistant_only_loss` and `completion_only_loss` enabled.

    Each completion is first extended with the first prompt message followed by a repeat of the first
    completion message.
    """
    prompt_completion = load_dataset(
        "trl-internal-testing/zen", "conversational_prompt_completion", split="train"
    )

    def extend_completion(example):
        # Append the first prompt message, then the first completion message again.
        completion = example["completion"]
        completion.append(example["prompt"][0])
        completion.append(completion[0])
        return example

    prompt_completion = prompt_completion.map(extend_completion)

    config = SFTConfig(
        output_dir=self.tmp_dir,
        assistant_only_loss=True,
        completion_only_loss=True,
        report_to="none",
    )
    sft_trainer = SFTTrainer(
        model="trl-internal-testing/tiny-Qwen3ForCausalLM",
        args=config,
        train_dataset=prompt_completion,
    )

    # Snapshot every parameter so it can be compared with its post-training value.
    weights_before = {name: tensor.clone() for name, tensor in sft_trainer.model.named_parameters()}

    sft_trainer.train()

    last_log = sft_trainer.state.log_history[-1]
    assert last_log["train_loss"] is not None

    for name, old_value in weights_before.items():
        new_value = sft_trainer.model.get_parameter(name)
        assert not torch.allclose(old_value, new_value), f"Parameter {name} has not changed"
|
|
|
def test_train_assistant_only_iterable_dataset(self):
    """Train for three steps with `assistant_only_loss=True` on a dataset loaded with `streaming=True`."""
    stream = load_dataset(
        "trl-internal-testing/zen",
        "conversational_language_modeling",
        split="train",
        streaming=True,
    )

    config = SFTConfig(output_dir=self.tmp_dir, assistant_only_loss=True, max_steps=3, report_to="none")
    sft_trainer = SFTTrainer(
        model="trl-internal-testing/tiny-Qwen3ForCausalLM",
        args=config,
        train_dataset=stream,
    )

    # Snapshot every parameter so it can be compared with its post-training value.
    weights_before = {name: tensor.clone() for name, tensor in sft_trainer.model.named_parameters()}

    sft_trainer.train()

    last_log = sft_trainer.state.log_history[-1]
    assert last_log["train_loss"] is not None

    for name, old_value in weights_before.items():
        new_value = sft_trainer.model.get_parameter(name)
        assert not torch.allclose(old_value, new_value), f"Parameter {name} has not changed"
|
|
|
def test_train_with_set_chat_template_from_model(self):
    """Train a tiny GPTNeoX model on conversational data with `chat_template_path="Qwen/Qwen3-4B"`."""
    conversations = load_dataset("trl-internal-testing/zen", "conversational_language_modeling", split="train")

    config = SFTConfig(output_dir=self.tmp_dir, chat_template_path="Qwen/Qwen3-4B", report_to="none")
    sft_trainer = SFTTrainer(
        model="trl-internal-testing/tiny-GPTNeoXForCausalLM",
        args=config,
        train_dataset=conversations,
    )

    # Snapshot every parameter so it can be compared with its post-training value.
    weights_before = {name: tensor.clone() for name, tensor in sft_trainer.model.named_parameters()}

    sft_trainer.train()

    last_log = sft_trainer.state.log_history[-1]
    assert last_log["train_loss"] is not None

    for name, old_value in weights_before.items():
        new_value = sft_trainer.model.get_parameter(name)
        assert not torch.allclose(old_value, new_value), f"Parameter {name} has not changed"
|
|
|
def test_train_with_set_chat_template_from_path(self, lazy_shared_datadir):
    """Train with a chat template read from a local `template.jinja` file.

    After training, the `chat_template.jinja` stored in the `checkpoint-9` directory must match the original
    file's content.
    """
    conversations = load_dataset("trl-internal-testing/zen", "conversational_language_modeling", split="train")

    source_template = str(lazy_shared_datadir / "template.jinja")
    config = SFTConfig(
        output_dir=self.tmp_dir,
        chat_template_path=source_template,
        report_to="none",
    )
    sft_trainer = SFTTrainer(
        model="trl-internal-testing/tiny-GPTNeoXForCausalLM",
        args=config,
        train_dataset=conversations,
    )

    # Snapshot every parameter so it can be compared with its post-training value.
    weights_before = {name: tensor.clone() for name, tensor in sft_trainer.model.named_parameters()}

    sft_trainer.train()

    last_log = sft_trainer.state.log_history[-1]
    assert last_log["train_loss"] is not None

    for name, old_value in weights_before.items():
        new_value = sft_trainer.model.get_parameter(name)
        assert not torch.allclose(old_value, new_value), f"Parameter {name} has not changed"

    # The checkpoint directory must contain a copy of the template.
    template_path = pathlib.Path(self.tmp_dir) / "checkpoint-9" / "chat_template.jinja"
    assert template_path.exists(), f"Chat template not found at {template_path}"

    saved_text = template_path.read_text()
    original_text = pathlib.Path(config.chat_template_path).read_text()
    assert saved_text == original_text, "Chat template content does not match the original"
|
|
|
def test_train_toolcall_data(self):
    """Train on the `trl-internal-testing/toolcall` language-modeling dataset and check every parameter changes."""
    toolcall_data = load_dataset("trl-internal-testing/toolcall", "language_modeling", split="train")

    config = SFTConfig(output_dir=self.tmp_dir, report_to="none")
    sft_trainer = SFTTrainer(
        model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
        args=config,
        train_dataset=toolcall_data,
    )

    # Snapshot every parameter so it can be compared with its post-training value.
    weights_before = {name: tensor.clone() for name, tensor in sft_trainer.model.named_parameters()}

    sft_trainer.train()

    last_log = sft_trainer.state.log_history[-1]
    assert last_log["train_loss"] is not None

    for name, old_value in weights_before.items():
        new_value = sft_trainer.model.get_parameter(name)
        assert not torch.allclose(old_value, new_value), f"Parameter {name} has not changed"
|
|
|
def test_train_toolcall_data_as_json(self):
    """Train on tool-calling data whose `tools` column has been JSON-decoded beforehand."""
    toolcall_data = load_dataset("trl-internal-testing/toolcall", "language_modeling", split="train")

    def decode_tools(example):
        # Replace the serialized `tools` value with the object parsed from it.
        parsed = json.loads(example["tools"])
        return {"tools": parsed}

    toolcall_data = toolcall_data.map(decode_tools)

    config = SFTConfig(output_dir=self.tmp_dir, report_to="none")
    sft_trainer = SFTTrainer(
        model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
        args=config,
        train_dataset=toolcall_data,
    )

    # Snapshot every parameter so it can be compared with its post-training value.
    weights_before = {name: tensor.clone() for name, tensor in sft_trainer.model.named_parameters()}

    sft_trainer.train()

    last_log = sft_trainer.state.log_history[-1]
    assert last_log["train_loss"] is not None

    for name, old_value in weights_before.items():
        new_value = sft_trainer.model.get_parameter(name)
        assert not torch.allclose(old_value, new_value), f"Parameter {name} has not changed"
|
|
|
def test_train_with_eval(self):
    """Train with evaluation every 3 steps and check the first log entry holds an eval loss."""
    splits = load_dataset("trl-internal-testing/zen", "standard_language_modeling")

    config = SFTConfig(output_dir=self.tmp_dir, eval_strategy="steps", eval_steps=3, report_to="none")
    sft_trainer = SFTTrainer(
        model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
        args=config,
        train_dataset=splits["train"],
        eval_dataset=splits["test"],
    )

    sft_trainer.train()

    first_log = sft_trainer.state.log_history[0]
    assert first_log["eval_loss"] is not None
|
|
|
def test_train_with_multiple_eval_dataset(self):
    """Train with two named eval datasets and check both losses appear near the end of the log."""
    splits = load_dataset("trl-internal-testing/zen", "standard_language_modeling")

    # The same test split is registered under two names.
    eval_sets = {"data1": splits["test"], "data2": splits["test"]}
    config = SFTConfig(output_dir=self.tmp_dir, eval_strategy="steps", eval_steps=3, report_to="none")
    sft_trainer = SFTTrainer(
        model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
        args=config,
        train_dataset=splits["train"],
        eval_dataset=eval_sets,
    )

    sft_trainer.train()

    # Third-to-last and second-to-last entries carry the per-dataset eval losses.
    history = sft_trainer.state.log_history
    assert history[-3]["eval_data1_loss"] is not None
    assert history[-2]["eval_data2_loss"] is not None
|
|
|
def test_train_with_compute_metrics(self):
    """Train with a constant `compute_metrics` callback and check its value is logged as `eval_my_metric`."""
    splits = load_dataset("trl-internal-testing/zen", "standard_language_modeling")

    def constant_metric(eval_pred):
        # Ignores the predictions and always reports the same value.
        return {"my_metric": 0.123}

    config = SFTConfig(
        output_dir=self.tmp_dir,
        eval_strategy="steps",
        eval_steps=3,
        report_to="none",
    )
    sft_trainer = SFTTrainer(
        model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
        args=config,
        train_dataset=splits["train"],
        eval_dataset=splits["test"],
        compute_metrics=constant_metric,
    )

    sft_trainer.train()

    # The second-to-last log entry is expected to be an evaluation entry.
    eval_entry = sft_trainer.state.log_history[-2]
    assert eval_entry["eval_my_metric"] == 0.123
|
|
|
|
|
|
|
|
|
def test_train_with_gradient_checkpointing(self):
    """Train with `gradient_checkpointing=True` and check that every parameter changes."""
    train_data = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")

    config = SFTConfig(output_dir=self.tmp_dir, gradient_checkpointing=True, report_to="none")
    sft_trainer = SFTTrainer(
        model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
        args=config,
        train_dataset=train_data,
    )

    # Snapshot every parameter so it can be compared with its post-training value.
    weights_before = {name: tensor.clone() for name, tensor in sft_trainer.model.named_parameters()}

    sft_trainer.train()

    last_log = sft_trainer.state.log_history[-1]
    assert last_log["train_loss"] is not None

    for name, old_value in weights_before.items():
        new_value = sft_trainer.model.get_parameter(name)
        assert not torch.allclose(old_value, new_value), f"Parameter {name} has not changed"
|
|
|
@pytest.mark.parametrize("use_reentrant", [True, False])
def test_train_with_gradient_checkpointing_reentrant(self, use_reentrant):
    """Train with gradient checkpointing for both values of the `use_reentrant` kwarg."""
    train_data = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")

    checkpointing_kwargs = {"use_reentrant": use_reentrant}
    config = SFTConfig(
        output_dir=self.tmp_dir,
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs=checkpointing_kwargs,
        report_to="none",
    )
    sft_trainer = SFTTrainer(
        model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
        args=config,
        train_dataset=train_data,
    )

    # Snapshot every parameter so it can be compared with its post-training value.
    weights_before = {name: tensor.clone() for name, tensor in sft_trainer.model.named_parameters()}

    sft_trainer.train()

    last_log = sft_trainer.state.log_history[-1]
    assert last_log["train_loss"] is not None

    for name, old_value in weights_before.items():
        new_value = sft_trainer.model.get_parameter(name)
        assert not torch.allclose(old_value, new_value), f"Parameter {name} has not changed"
|
|
|
def test_tag_added(self):
    """Check that constructing the trainer puts the `sft` and `trl` tags on the model."""
    train_data = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")

    sft_trainer = SFTTrainer(
        model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
        train_dataset=train_data,
    )

    expected_tags = ["sft", "trl"]
    for expected in expected_tags:
        assert expected in sft_trainer.model.model_tags
|
|
|
@require_peft
def test_tag_added_peft(self):
    """Check that the `sft` and `trl` tags are on the model when a default `LoraConfig` is passed."""
    train_data = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")

    lora = LoraConfig()
    sft_trainer = SFTTrainer(
        model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
        train_dataset=train_data,
        peft_config=lora,
    )

    expected_tags = ["sft", "trl"]
    for expected in expected_tags:
        assert expected in sft_trainer.model.model_tags
|
|
|
@pytest.mark.parametrize(
    "model_id",
    [
        "trl-internal-testing/tiny-Gemma3ForConditionalGeneration",
        pytest.param(
            "trl-internal-testing/tiny-Gemma4ForConditionalGeneration",
            marks=pytest.mark.skipif(
                Version(transformers.__version__) < Version("5.5.0"),
                reason="Gemma4 models were introduced in transformers-5.5.0",
            ),
        ),
        "trl-internal-testing/tiny-LlavaForConditionalGeneration",
        "trl-internal-testing/tiny-LlavaNextForConditionalGeneration",
        "trl-internal-testing/tiny-Qwen2VLForConditionalGeneration",
        "trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
        pytest.param(
            "trl-internal-testing/tiny-Qwen3VLForConditionalGeneration",
            marks=[
                pytest.mark.skipif(
                    Version(transformers.__version__) < Version("4.57.0"),
                    reason="Qwen3-VL series were introduced in transformers-4.57.0",
                ),
            ],
        ),
        pytest.param(
            "trl-internal-testing/tiny-Qwen3_5ForConditionalGeneration",
            marks=pytest.mark.skipif(
                Version(transformers.__version__) < Version("5.2.0"),
                reason="Qwen3.5 models were introduced in transformers-5.2.0",
            ),
        ),
    ],
)
@require_vision
def test_train_vlm(self, model_id):
    """Train each parametrized vision-language model on image conversations.

    Every parameter must change, except parameters whose name contains one of the substrings listed for the
    model in `exempt_fragments`; those are skipped by the final check.
    """
    image_conversations = load_dataset(
        "trl-internal-testing/zen-image", "conversational_language_modeling", split="train"
    )

    config = SFTConfig(
        output_dir=self.tmp_dir,
        max_length=None,
        report_to="none",
    )
    sft_trainer = SFTTrainer(model=model_id, args=config, train_dataset=image_conversations)

    # Snapshot every parameter so it can be compared with its post-training value.
    weights_before = {name: tensor.clone() for name, tensor in sft_trainer.model.named_parameters()}

    sft_trainer.train()

    last_log = sft_trainer.state.log_history[-1]
    assert last_log["train_loss"] is not None

    # Per-model name substrings; a parameter containing any of them is not required to change.
    llava_fragments = (
        "model.vision_tower.vision_model.post_layernorm",
        "vision_tower.vision_model.encoder.layers.1",
        "model.vision_tower.encoder.layers.1",
        "model.vision_tower.post_layernorm",
    )
    exempt_fragments = {
        "trl-internal-testing/tiny-Gemma3ForConditionalGeneration": ("model.vision_tower.vision_model.head",),
        "trl-internal-testing/tiny-LlavaForConditionalGeneration": llava_fragments,
        "trl-internal-testing/tiny-LlavaNextForConditionalGeneration": llava_fragments,
        "trl-internal-testing/tiny-Qwen3VLForConditionalGeneration": ("model.visual.deepstack_merger_list",),
    }
    skipped_for_model = exempt_fragments.get(model_id, ())

    for name, old_value in weights_before.items():
        new_value = sft_trainer.model.get_parameter(name)
        if any(fragment in name for fragment in skipped_for_model):
            continue
        assert not torch.allclose(old_value, new_value, rtol=1e-12, atol=1e-12), f"Param {name} is not updated"
|
|
|
@pytest.mark.parametrize(
    "model_id",
    [
        "trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
    ],
)
@pytest.mark.xfail(
    parse_version(transformers.__version__) < parse_version("4.57.0"),
    reason="Mixing text-only and image+text examples is only supported in transformers >= 4.57.0",
    strict=False,
)
@require_vision
def test_train_vlm_multi_image(self, model_id):
    """Train a vision-language model on the multi-image prompt-completion dataset (learning rate 0.1)."""
    multi_image_data = load_dataset(
        "trl-internal-testing/zen-multi-image",
        "conversational_prompt_completion",
        split="train",
    )

    config = SFTConfig(
        output_dir=self.tmp_dir,
        learning_rate=0.1,
        max_length=None,
        report_to="none",
    )
    sft_trainer = SFTTrainer(model=model_id, args=config, train_dataset=multi_image_data)

    # Snapshot every parameter so it can be compared with its post-training value.
    weights_before = {name: tensor.clone() for name, tensor in sft_trainer.model.named_parameters()}

    sft_trainer.train()

    last_log = sft_trainer.state.log_history[-1]
    assert last_log["train_loss"] is not None

    for name, old_value in weights_before.items():
        new_value = sft_trainer.model.get_parameter(name)
        assert not torch.allclose(old_value, new_value, rtol=1e-12, atol=1e-12), f"Param {name} is not updated"
|
|
|
@pytest.mark.parametrize(
    "model_id",
    [
        "trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
        "trl-internal-testing/tiny-Gemma3ForConditionalGeneration",
        pytest.param(
            "trl-internal-testing/tiny-Gemma4ForConditionalGeneration",
            marks=pytest.mark.skipif(
                Version(transformers.__version__) < Version("5.5.0"),
                reason="Gemma4 models were introduced in transformers-5.5.0",
            ),
        ),
    ],
)
@require_vision
def test_train_vlm_prompt_completion(self, model_id):
    """Train each parametrized vision-language model on image prompt-completion data (learning rate 0.1)."""
    image_prompt_completion = load_dataset(
        "trl-internal-testing/zen-image", "conversational_prompt_completion", split="train"
    )

    config = SFTConfig(
        output_dir=self.tmp_dir,
        learning_rate=0.1,
        max_length=None,
        report_to="none",
    )
    sft_trainer = SFTTrainer(model=model_id, args=config, train_dataset=image_prompt_completion)

    # Snapshot every parameter so it can be compared with its post-training value.
    weights_before = {name: tensor.clone() for name, tensor in sft_trainer.model.named_parameters()}

    sft_trainer.train()

    last_log = sft_trainer.state.log_history[-1]
    assert last_log["train_loss"] is not None

    for name, old_value in weights_before.items():
        new_value = sft_trainer.model.get_parameter(name)
        assert not torch.allclose(old_value, new_value, rtol=1e-12, atol=1e-12), f"Param {name} is not updated"
|
|
|
|
|
|
|
@pytest.mark.slow
@require_vision
@pytest.mark.skip(reason="Model google/gemma-3n-E2B-it is gated and requires HF token")
def test_train_vlm_gemma_3n(self):
    """Train `google/gemma-3n-E2B-it` in bfloat16 on image conversations (currently skipped).

    Parameters whose name contains `model.audio_tower` or `model.embed_audio` are excluded from the
    "has changed" check.
    """
    image_conversations = load_dataset(
        "trl-internal-testing/zen-image", "conversational_language_modeling", split="train"
    )

    config = SFTConfig(
        output_dir=self.tmp_dir,
        learning_rate=0.1,
        max_length=None,
        per_device_train_batch_size=1,
        model_init_kwargs={"dtype": "bfloat16"},
        report_to="none",
    )
    sft_trainer = SFTTrainer(model="google/gemma-3n-E2B-it", args=config, train_dataset=image_conversations)

    # Snapshot every parameter so it can be compared with its post-training value.
    weights_before = {name: tensor.clone() for name, tensor in sft_trainer.model.named_parameters()}

    sft_trainer.train()

    last_log = sft_trainer.state.log_history[-1]
    assert last_log["train_loss"] is not None

    audio_fragments = ("model.audio_tower", "model.embed_audio")
    for name, old_value in weights_before.items():
        new_value = sft_trainer.model.get_parameter(name)
        if any(fragment in name for fragment in audio_fragments):
            # Audio-related parameters are not checked.
            continue
        assert not torch.allclose(old_value, new_value, rtol=1e-12, atol=1e-12), f"Param {name} is not updated"
|
|
|
@pytest.mark.parametrize(
    "model_id",
    [
        "trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
    ],
)
@pytest.mark.parametrize(
    "dataset_config",
    [
        "conversational_language_modeling",
        "conversational_prompt_completion",
        "standard_language_modeling",
        "standard_prompt_completion",
    ],
)
@require_vision
def test_train_vlm_text_only_data(self, model_id, dataset_config):
    """Train a VLM on text-only data.

    Parameters whose name starts with "model.visual" must stay unchanged, while every other parameter must
    differ from its pre-training value.
    """
    dataset = load_dataset("trl-internal-testing/zen", dataset_config, split="train")

    training_args = SFTConfig(output_dir=self.tmp_dir, report_to="none")
    trainer = SFTTrainer(
        model=model_id,
        args=training_args,
        train_dataset=dataset,
    )

    # Snapshot the weights before training so they can be compared afterwards.
    previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

    trainer.train()

    # The final log entry must carry a training loss.
    assert trainer.state.log_history[-1]["train_loss"] is not None

    for n, param in previous_trainable_params.items():
        new_param = trainer.model.get_parameter(n)
        if n.startswith("model.visual"):
            # Pass the context through `msg`; a trailing ", f'...'" after the call would only build a
            # throwaway tuple and the parameter name would never reach the failure report.
            torch.testing.assert_close(
                param,
                new_param,
                rtol=1e-12,
                atol=1e-12,
                msg=lambda details, name=n: f"Param {name} is updated\n{details}",
            )
        else:
            assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is not updated"
|
|
|
@require_peft
def test_prompt_tuning(self):
    """Test that SFT works with a virtual-token PEFT adapter configured through `PromptEncoderConfig`.

    After training, parameters whose name contains "base_model" must be unchanged and parameters whose name
    contains "prompt_encoder" must have changed; any other parameter name is treated as an error.
    """
    dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")

    training_args = SFTConfig(output_dir=self.tmp_dir, report_to="none")
    trainer = SFTTrainer(
        model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
        args=training_args,
        train_dataset=dataset,
        peft_config=PromptEncoderConfig(task_type=TaskType.CAUSAL_LM, num_virtual_tokens=8),
    )

    # Snapshot the weights before training so they can be compared afterwards.
    previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

    trainer.train()

    # The final log entry must carry both metrics.
    assert trainer.state.log_history[-1]["train_loss"] is not None
    assert trainer.state.log_history[-1]["mean_token_accuracy"] is not None

    for n, param in previous_trainable_params.items():
        new_param = trainer.model.get_parameter(n)
        if "base_model" in n:
            # Pass the context through `msg`; a trailing ", f'...'" after the call would only build a
            # throwaway tuple and the parameter name would never reach the failure report.
            torch.testing.assert_close(
                param,
                new_param,
                msg=lambda details, name=n: f"Parameter {name} has changed\n{details}",
            )
        elif "prompt_encoder" in n:
            assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"
        else:
            raise ValueError(f"Unexpected parameter {n} in model: {trainer.model}")
|
|
|
@require_peft
@require_bitsandbytes
def test_peft_with_quantization(self):
    """Train a LoRA adapter on top of a 4-bit (NF4, double-quantized) bitsandbytes model.

    After training, parameters whose name contains "lora" must have changed and all other parameters must be
    unchanged.
    """
    model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"

    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        dtype="float32",
        quantization_config=quantization_config,
    )

    dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")

    training_args = SFTConfig(output_dir=self.tmp_dir, learning_rate=0.1, report_to="none")
    trainer = SFTTrainer(model=model, args=training_args, train_dataset=dataset, peft_config=LoraConfig())

    # Snapshot the weights before training so they can be compared afterwards.
    previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

    trainer.train()

    # The final log entry must carry both metrics.
    assert trainer.state.log_history[-1]["train_loss"] is not None
    assert trainer.state.log_history[-1]["mean_token_accuracy"] is not None

    # These bias snapshots are cast to float32 before the comparison.
    # NOTE(review): presumably their dtype changes during the quantized forward pass -- confirm against
    # bitsandbytes behavior.
    float_compared_biases = {
        "base_model.model.model.layers.1.self_attn.k_proj.bias",
        "base_model.model.model.layers.1.self_attn.q_proj.base_layer.bias",
        "base_model.model.model.layers.1.self_attn.v_proj.base_layer.bias",
    }

    for n, param in previous_trainable_params.items():
        new_param = trainer.model.get_parameter(n)
        expected = param.float() if n in float_compared_biases else param

        # A name either contains "lora" or it does not, so a two-way branch covers every parameter.
        if "lora" in n:
            assert not torch.allclose(expected, new_param), f"Parameter {n} has not changed"
        else:
            # Pass the context through `msg`; a trailing ", f'...'" after the call would only build a
            # throwaway tuple and the parameter name would never reach the failure report.
            torch.testing.assert_close(
                expected,
                new_param,
                msg=lambda details, name=n: f"Parameter {name} has changed\n{details}",
            )
|
|
|
@require_peft
def test_prompt_tuning_peft_model(self):
    """Test that SFT works when given a model already wrapped by `get_peft_model` with `PromptEncoderConfig`.

    After training, parameters whose name contains "base_model" must be unchanged and parameters whose name
    contains "prompt_encoder" must have changed; any other parameter name is treated as an error.
    """
    model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", dtype="float32")
    model = get_peft_model(model, PromptEncoderConfig(task_type=TaskType.CAUSAL_LM, num_virtual_tokens=8))

    dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")

    training_args = SFTConfig(output_dir=self.tmp_dir, report_to="none")
    trainer = SFTTrainer(model=model, args=training_args, train_dataset=dataset)

    # Snapshot the weights before training so they can be compared afterwards.
    previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

    trainer.train()

    # The final log entry must carry both metrics.
    assert trainer.state.log_history[-1]["train_loss"] is not None
    assert trainer.state.log_history[-1]["mean_token_accuracy"] is not None

    for n, param in previous_trainable_params.items():
        new_param = trainer.model.get_parameter(n)
        if "base_model" in n:
            # Pass the context through `msg`; a trailing ", f'...'" after the call would only build a
            # throwaway tuple and the parameter name would never reach the failure report.
            torch.testing.assert_close(
                param,
                new_param,
                msg=lambda details, name=n: f"Parameter {name} has changed\n{details}",
            )
        elif "prompt_encoder" in n:
            assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"
        else:
            raise ValueError(f"Unexpected parameter {n} in model: {trainer.model}")
|
|
|
|
|
| @pytest.mark.slow
|
| @require_torch_accelerator
|
| @require_peft
|
| class TestSFTTrainerSlow(TrlTestCase):
|
def setup_method(self):
    """Prepare the fixtures shared by the slow tests: IMDB slices, a sequence cap and a LoRA config."""
    imdb_repo = "stanfordnlp/imdb"
    # Only the first 10% of each split is loaded.
    self.train_dataset = load_dataset(imdb_repo, split="train[:10%]")
    self.eval_dataset = load_dataset(imdb_repo, split="test[:10%]")

    self.max_length = 128

    lora_settings = {
        "r": 8,
        "lora_alpha": 16,
        "lora_dropout": 0.1,
        "bias": "none",
        "task_type": "CAUSAL_LM",
    }
    self.peft_config = LoraConfig(**lora_settings)
|
|
|
def teardown_method(self):
    """Run garbage collection and empty the accelerator cache once a test has finished."""
    # Collect on both sides of the cache flush; the call order is deliberate and kept as-is.
    gc.collect()
    backend_empty_cache(torch_device)
    gc.collect()
|
|
|
@pytest.mark.parametrize("packing", [True, False])
@pytest.mark.parametrize(
    "model_name",
    [
        "trl-internal-testing/tiny-LlamaForCausalLM-3.2",
        "trl-internal-testing/tiny-MistralForCausalLM-0.2",
    ],
)
def test_sft_trainer_transformers_mp(self, model_name, packing):
    """
    Check that `SFTTrainer` accepts an already-loaded transformers model and completes a short fp16
    mixed-precision run, with and without packing.
    """
    config = SFTConfig(
        output_dir=self.tmp_dir,
        report_to="none",
        logging_strategy="no",
        max_steps=10,
        per_device_train_batch_size=2,
        max_length=self.max_length,
        packing=packing,
        fp16=True,
    )

    # Weights are loaded in float32; mixed precision comes from `fp16=True` in the config.
    base_model = AutoModelForCausalLM.from_pretrained(model_name, dtype="float32")
    tok = AutoTokenizer.from_pretrained(model_name)

    trainer = SFTTrainer(
        base_model,
        args=config,
        processing_class=tok,
        train_dataset=self.train_dataset,
        eval_dataset=self.eval_dataset,
    )
    trainer.train()

    release_memory(base_model, trainer)
|
|
|
@pytest.mark.parametrize("device_map", [{"": 0}, "auto"])
@pytest.mark.parametrize(
    "gradient_checkpointing_kwargs", [None, {"use_reentrant": False}, {"use_reentrant": True}]
)
@pytest.mark.parametrize("packing", [True, False])
@pytest.mark.parametrize(
    "model_name",
    [
        "trl-internal-testing/tiny-LlamaForCausalLM-3.2",
        "trl-internal-testing/tiny-MistralForCausalLM-0.2",
    ],
)
@require_torch_multi_accelerator
def test_sft_trainer_transformers_mp_gc_device_map(
    self, model_name, packing, gradient_checkpointing_kwargs, device_map
):
    """
    Check that `SFTTrainer` completes a short fp16 run with gradient checkpointing enabled, across the
    parametrized checkpointing kwargs and device maps (single device and "auto").
    """
    config = SFTConfig(
        output_dir=self.tmp_dir,
        report_to="none",
        logging_strategy="no",
        max_steps=10,
        per_device_train_batch_size=2,
        max_length=self.max_length,
        packing=packing,
        fp16=True,
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs=gradient_checkpointing_kwargs,
    )

    # The device map decides where the float32 weights are placed at load time.
    base_model = AutoModelForCausalLM.from_pretrained(model_name, dtype="float32", device_map=device_map)
    tok = AutoTokenizer.from_pretrained(model_name)

    trainer = SFTTrainer(
        base_model,
        args=config,
        processing_class=tok,
        train_dataset=self.train_dataset,
        eval_dataset=self.eval_dataset,
    )
    trainer.train()

    release_memory(base_model, trainer)
|
|
|
@pytest.mark.parametrize(
    "gradient_checkpointing_kwargs", [None, {"use_reentrant": False}, {"use_reentrant": True}]
)
@pytest.mark.parametrize("packing", [True, False])
@pytest.mark.parametrize(
    "model_name",
    [
        "trl-internal-testing/tiny-LlamaForCausalLM-3.2",
        "trl-internal-testing/tiny-MistralForCausalLM-0.2",
    ],
)
@require_peft
@require_bitsandbytes
def test_sft_trainer_transformers_mp_gc_peft_qlora(self, model_name, packing, gradient_checkpointing_kwargs):
    """
    Check that `SFTTrainer` wraps a 4-bit bitsandbytes model with the shared LoRA config and completes a
    short run under each parametrized gradient-checkpointing setting.
    """
    config = SFTConfig(
        output_dir=self.tmp_dir,
        report_to="none",
        logging_strategy="no",
        max_steps=10,
        per_device_train_batch_size=2,
        max_length=self.max_length,
        packing=packing,
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs=gradient_checkpointing_kwargs,
    )

    # 4-bit weights with float16 as the compute dtype.
    bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)
    base_model = AutoModelForCausalLM.from_pretrained(model_name, dtype="float32", quantization_config=bnb_config)
    tok = AutoTokenizer.from_pretrained(model_name)

    trainer = SFTTrainer(
        base_model,
        args=config,
        processing_class=tok,
        train_dataset=self.train_dataset,
        eval_dataset=self.eval_dataset,
        peft_config=self.peft_config,
    )

    # Passing a PEFT config must leave the trainer holding a PeftModel.
    assert isinstance(trainer.model, PeftModel)

    trainer.train()

    release_memory(base_model, trainer)
|
|
|
@pytest.mark.parametrize("packing", [True, False])
@pytest.mark.parametrize(
    "model_name",
    [
        "trl-internal-testing/tiny-LlamaForCausalLM-3.2",
        "trl-internal-testing/tiny-MistralForCausalLM-0.2",
    ],
)
@require_peft
@require_bitsandbytes
def test_sft_trainer_with_chat_format_qlora(self, model_name, packing):
    """
    Check that `SFTTrainer` trains a 4-bit bitsandbytes model with the shared LoRA config on the
    `trl-internal-testing/dolly-chatml-sft` dataset, with and without packing.
    """
    chat_data = load_dataset("trl-internal-testing/dolly-chatml-sft", split="train")

    config = SFTConfig(
        output_dir=self.tmp_dir,
        report_to="none",
        logging_strategy="no",
        max_steps=10,
        per_device_train_batch_size=2,
        max_length=self.max_length,
        packing=packing,
    )

    # 4-bit weights with float16 as the compute dtype.
    bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)
    base_model = AutoModelForCausalLM.from_pretrained(model_name, dtype="float32", quantization_config=bnb_config)
    tok = AutoTokenizer.from_pretrained(model_name)

    trainer = SFTTrainer(
        base_model,
        args=config,
        processing_class=tok,
        train_dataset=chat_data,
        peft_config=self.peft_config,
    )

    # Passing a PEFT config must leave the trainer holding a PeftModel.
    assert isinstance(trainer.model, PeftModel)

    trainer.train()

    release_memory(base_model, trainer)
|
|
|
@pytest.mark.parametrize("packing", [True, False])
@pytest.mark.parametrize(
    "model_name",
    [
        "trl-internal-testing/tiny-LlamaForCausalLM-3.2",
        "trl-internal-testing/tiny-MistralForCausalLM-0.2",
    ],
)
@require_liger_kernel
def test_sft_trainer_with_liger(self, model_name, packing):
    """
    Check that a two-step run works with `SFTConfig(use_liger_kernel=True)`, then reload the module that
    defines the model class as a best-effort way of cleaning up the liger_kernel patches.
    """
    import importlib

    config = SFTConfig(
        output_dir=self.tmp_dir,
        report_to="none",
        logging_strategy="no",
        max_steps=2,
        per_device_train_batch_size=2,
        max_length=self.max_length,
        packing=packing,
        use_liger_kernel=True,
    )

    trainer = SFTTrainer(
        model_name,
        args=config,
        train_dataset=self.train_dataset,
        eval_dataset=self.eval_dataset,
    )

    try:
        trainer.train()
        release_memory(trainer.model, trainer)
    finally:
        # Cleanup runs whether or not training succeeded. It is deliberately best-effort:
        # a failure while reloading must not mask the outcome of the test itself.
        try:
            patched_module = importlib.import_module(trainer.model.__module__)
            importlib.reload(patched_module)
        except Exception:
            pass
|
|
|
@pytest.mark.parametrize("packing", [True, False])
@pytest.mark.parametrize(
    "model_name",
    [
        "trl-internal-testing/tiny-LlamaForCausalLM-3.2",
        "trl-internal-testing/tiny-MistralForCausalLM-0.2",
    ],
)
@require_torch_accelerator
def test_train_offloading(self, model_name, packing):
    """Check that a two-step run with `activation_offloading=True` logs a loss and changes every parameter."""
    config = SFTConfig(
        output_dir=self.tmp_dir,
        report_to="none",
        max_steps=2,
        per_device_train_batch_size=2,
        max_length=self.max_length,
        packing=packing,
        activation_offloading=True,
    )
    trainer = SFTTrainer(
        model=model_name,
        args=config,
        train_dataset=self.train_dataset,
        eval_dataset=self.eval_dataset,
    )

    # Snapshot the weights before training so they can be compared afterwards.
    snapshot = {name: weights.clone() for name, weights in trainer.model.named_parameters()}

    trainer.train()

    # The final log entry must carry a training loss.
    final_log = trainer.state.log_history[-1]
    assert final_log["train_loss"] is not None

    # Every parameter is expected to differ from its pre-training value.
    for name, before in snapshot.items():
        after = trainer.model.get_parameter(name)
        assert not torch.allclose(before, after), f"Parameter {name} has not changed"

    release_memory(trainer.model, trainer)
|
|
|