# Copyright 2020-2026 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import gc import json import pathlib from unittest.mock import MagicMock, patch import pytest import torch import transformers from accelerate.utils.memory import release_memory from datasets import load_dataset from packaging.version import Version from packaging.version import parse as parse_version from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments from transformers.testing_utils import backend_empty_cache, torch_device from transformers.utils import is_peft_available from trl import SFTConfig, SFTTrainer from trl.trainer.sft_trainer import DataCollatorForLanguageModeling, dft_loss from .testing_utils import ( TrlTestCase, ignore_warnings, require_ampere_or_newer, require_bitsandbytes, require_kernels, require_liger_kernel, require_peft, require_torch_accelerator, require_torch_multi_accelerator, require_vision, ) if is_peft_available(): import peft from peft import ( LoraConfig, PeftModel, PrefixTuningConfig, PromptEncoderConfig, PromptTuningConfig, TaskType, get_peft_model, ) class TestDFTLoss(TrlTestCase): def test_dft_loss(self): batch_size = 2 seq_len = 3 vocab_size = 2 # All tokens have the same probability logits = torch.fill(torch.empty(batch_size, seq_len, vocab_size), torch.rand(1).item()) outputs = MagicMock() outputs.logits = logits labels = torch.tensor([[1, 0, 0], [0, 1, -100]]) ce_loss = 
torch.nn.functional.cross_entropy( logits.view(-1, vocab_size), labels.view(-1), ignore_index=-100, reduction="mean" ) # We need to account for the logits shift operation so we don't consider the first tokens # in each row of the batch num_items_in_batch = 3 # Dft loss predicted_dft_loss = dft_loss(outputs, labels, num_items_in_batch) # If we have just two tokens in our vocab and all logits are the same, # dft scales the ce_loss per token by 0.5. So the dft_loss should be ce_loss/2 torch.testing.assert_close(ce_loss / 2.0, predicted_dft_loss, atol=1e-4, rtol=1e-4) class TestDataCollatorForLanguageModeling(TrlTestCase): def test_basic_padding(self): """Test basic padding functionality without completion masks.""" collator = DataCollatorForLanguageModeling(pad_token_id=0) examples = [{"input_ids": [1, 2, 3]}, {"input_ids": [4, 5]}] result = collator(examples) assert set(result.keys()) == {"input_ids", "attention_mask", "labels"} torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3], [4, 5, 0]])) torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0]])) torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3], [4, 5, -100]])) def test_completion_mask(self): """Test completion mask functionality.""" collator = DataCollatorForLanguageModeling(pad_token_id=0) examples = [ {"input_ids": [1, 2, 3], "completion_mask": [0, 1, 1]}, {"input_ids": [4, 5], "completion_mask": [0, 1]}, ] result = collator(examples) assert set(result.keys()) == {"input_ids", "attention_mask", "labels"} torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3], [4, 5, 0]])) torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0]])) torch.testing.assert_close(result["labels"], torch.tensor([[-100, 2, 3], [-100, 5, -100]])) def test_completion_only_loss_disabled(self): """Test behavior when completion_only_loss is disabled.""" collator = DataCollatorForLanguageModeling(pad_token_id=0, 
completion_only_loss=False) examples = [ {"input_ids": [1, 2, 3], "completion_mask": [0, 1, 1]}, {"input_ids": [4, 5], "completion_mask": [0, 1]}, ] result = collator(examples) # Labels should not be masked when completion_only_loss=False assert set(result.keys()) == {"input_ids", "attention_mask", "labels"} torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3], [4, 5, 0]])) torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0]])) torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3], [4, 5, -100]])) def test_padding_free_mode(self): """Test padding-free mode where sequences are concatenated.""" collator = DataCollatorForLanguageModeling(pad_token_id=0, padding_free=True) examples = [{"input_ids": [1, 2, 3]}, {"input_ids": [4, 5]}] result = collator(examples) assert set(result.keys()) == {"input_ids", "position_ids", "labels"} torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3, 4, 5]])) torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2, 0, 1]])) torch.testing.assert_close(result["labels"], torch.tensor([[-100, 2, 3, -100, 5]])) def test_padding_free_with_completion_mask(self): """Test padding-free mode with completion masks.""" collator = DataCollatorForLanguageModeling(pad_token_id=0, padding_free=True) examples = [ {"input_ids": [1, 2, 3], "completion_mask": [0, 0, 1]}, {"input_ids": [4, 5], "completion_mask": [1, 1]}, ] result = collator(examples) assert set(result.keys()) == {"input_ids", "position_ids", "labels"} torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3, 4, 5]])) torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2, 0, 1]])) torch.testing.assert_close(result["labels"], torch.tensor([[-100, -100, 3, -100, 5]])) def test_packing(self): """Test that when using packing with position_ids, attention_mask is dropped with fa2.""" collator = DataCollatorForLanguageModeling(pad_token_id=0, padding_free=True) # 
Simulate packed sequences with position_ids that restart (typical of BFD packing) examples = [ {"input_ids": [1, 2, 3, 4, 5, 6], "seq_lengths": [3, 3]}, {"input_ids": [7, 8, 9, 10, 11], "seq_lengths": [4, 1]}, ] result = collator(examples) assert set(result.keys()) == {"input_ids", "position_ids", "labels"} torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]])) torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2, 0, 1, 2, 0, 1, 2, 3, 0]])) torch.testing.assert_close(result["labels"], torch.tensor([[-100, 2, 3, -100, 5, 6, -100, 8, 9, 10, -100]])) def test_pad_to_multiple_of(self): """Test padding to multiple of specified value.""" collator = DataCollatorForLanguageModeling(pad_token_id=0, pad_to_multiple_of=4) examples = [{"input_ids": [1, 2, 3]}, {"input_ids": [4, 5]}] result = collator(examples) assert set(result.keys()) == {"input_ids", "attention_mask", "labels"} torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3, 0], [4, 5, 0, 0]])) torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])) torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3, -100], [4, 5, -100, -100]])) def test_pad_to_multiple_of_and_padding_free(self): """Test padding to multiple of specified value.""" collator = DataCollatorForLanguageModeling(pad_token_id=0, padding_free=True, pad_to_multiple_of=4) examples = [{"input_ids": [1, 2, 3]}, {"input_ids": [4, 5]}] result = collator(examples) assert set(result.keys()) == {"input_ids", "position_ids", "labels"} torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3, 4, 5, 0, 0, 0]])) torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2, 0, 1, 0, 0, 0]])) torch.testing.assert_close(result["labels"], torch.tensor([[-100, 2, 3, -100, 5, -100, -100, -100]])) def test_custom_position_ids_but_no_padding_free(self): """Test that custom position_ids are ignored if padding_free 
is False.""" collator = DataCollatorForLanguageModeling(pad_token_id=0) examples = [{"input_ids": [1, 2, 3], "seq_lengths": [1, 2]}, {"input_ids": [4, 5], "seq_lengths": [2]}] result = collator(examples) assert set(result.keys()) == {"input_ids", "attention_mask", "labels"} torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3], [4, 5, 0]])) torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0]])) torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3], [4, 5, -100]])) def test_single_example(self): """Test collator with a single example.""" collator = DataCollatorForLanguageModeling(pad_token_id=0) examples = [{"input_ids": [1, 2, 3, 4]}] result = collator(examples) assert set(result.keys()) == {"input_ids", "attention_mask", "labels"} torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3, 4]])) torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1, 1]])) torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3, 4]])) def test_different_pad_token_id(self): """Test with different pad token ID.""" collator = DataCollatorForLanguageModeling(pad_token_id=999) examples = [{"input_ids": [1, 2, 3]}, {"input_ids": [4, 5]}] result = collator(examples) assert set(result.keys()) == {"input_ids", "attention_mask", "labels"} torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3], [4, 5, 999]])) torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0]])) torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3], [4, 5, -100]])) def test_assistant_masks(self): """Test handling of assistant masks in examples.""" collator = DataCollatorForLanguageModeling(pad_token_id=0) examples = [ {"input_ids": [1, 2, 3], "assistant_masks": [0, 1, 1]}, {"input_ids": [4, 5], "assistant_masks": [0, 1]}, ] result = collator(examples) torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3], [4, 5, 0]])) 
torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0]])) torch.testing.assert_close(result["labels"], torch.tensor([[-100, 2, 3], [-100, 5, -100]])) def test_max_length_keep_start(self): """Test that sequences longer than max_length are truncated from the start.""" collator = DataCollatorForLanguageModeling(pad_token_id=0, max_length=3) examples = [{"input_ids": [1, 2, 3, 4, 5]}, {"input_ids": [6, 7, 8]}] result = collator(examples) assert set(result.keys()) == {"input_ids", "attention_mask", "labels"} torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3], [6, 7, 8]])) torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 1]])) torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3], [6, 7, 8]])) def test_max_length_keep_end(self): """Test that sequences longer than max_length are truncated from the end (keeping last tokens).""" collator = DataCollatorForLanguageModeling(pad_token_id=0, max_length=3, truncation_mode="keep_end") examples = [{"input_ids": [1, 2, 3, 4, 5]}, {"input_ids": [6, 7, 8]}] result = collator(examples) assert set(result.keys()) == {"input_ids", "attention_mask", "labels"} torch.testing.assert_close(result["input_ids"], torch.tensor([[3, 4, 5], [6, 7, 8]])) torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 1]])) torch.testing.assert_close(result["labels"], torch.tensor([[3, 4, 5], [6, 7, 8]])) def test_max_length_no_truncation_needed(self): """Test that max_length larger than sequences does not alter the output.""" collator = DataCollatorForLanguageModeling(pad_token_id=0, max_length=10) examples = [{"input_ids": [1, 2, 3]}, {"input_ids": [4, 5]}] result = collator(examples) assert set(result.keys()) == {"input_ids", "attention_mask", "labels"} torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3], [4, 5, 0]])) torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 
0]])) torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3], [4, 5, -100]])) def test_max_length_with_completion_mask(self): """Test that truncation is applied correctly when completion masks are present.""" collator = DataCollatorForLanguageModeling(pad_token_id=0, max_length=3) examples = [ {"input_ids": [1, 2, 3, 4, 5], "completion_mask": [0, 0, 1, 1, 1]}, {"input_ids": [6, 7, 8], "completion_mask": [0, 1, 1]}, ] result = collator(examples) assert set(result.keys()) == {"input_ids", "attention_mask", "labels"} torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3], [6, 7, 8]])) torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 1]])) torch.testing.assert_close(result["labels"], torch.tensor([[-100, -100, 3], [-100, 7, 8]])) def test_max_length_keep_end_with_completion_mask(self): """Test keep_end truncation with completion masks preserves the final tokens.""" collator = DataCollatorForLanguageModeling(pad_token_id=0, max_length=3, truncation_mode="keep_end") examples = [ {"input_ids": [1, 2, 3, 4, 5], "completion_mask": [0, 0, 1, 1, 1]}, {"input_ids": [6, 7, 8], "completion_mask": [0, 1, 1]}, ] result = collator(examples) assert set(result.keys()) == {"input_ids", "attention_mask", "labels"} torch.testing.assert_close(result["input_ids"], torch.tensor([[3, 4, 5], [6, 7, 8]])) torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 1]])) torch.testing.assert_close(result["labels"], torch.tensor([[3, 4, 5], [-100, 7, 8]])) def test_max_length_invalid_truncation_mode(self): """Test that an invalid truncation_mode raises ValueError.""" collator = DataCollatorForLanguageModeling(pad_token_id=0, max_length=3, truncation_mode="invalid") examples = [{"input_ids": [1, 2, 3, 4, 5]}] with pytest.raises(ValueError, match="Unsupported truncation mode"): collator(examples) def test_single_example_single_doc(self): batch_seq_lengths = [[5]] result = 
DataCollatorForLanguageModeling.get_position_ids_from_packed_seq_lengths(batch_seq_lengths) assert len(result) == 1 assert torch.equal(result[0], torch.arange(5)) def test_single_example_multiple_docs(self): batch_seq_lengths = [[3, 2]] result = DataCollatorForLanguageModeling.get_position_ids_from_packed_seq_lengths(batch_seq_lengths) assert len(result) == 1 # First sequence: 0, 1, 2; second sequence: 0, 1 assert torch.equal(result[0], torch.tensor([0, 1, 2, 0, 1])) def test_multiple_examples(self): batch_seq_lengths = [[2, 2], [3]] result = DataCollatorForLanguageModeling.get_position_ids_from_packed_seq_lengths(batch_seq_lengths) assert len(result) == 2 assert torch.equal(result[0], torch.tensor([0, 1, 0, 1])) assert torch.equal(result[1], torch.arange(3)) class TestSFTTrainer(TrlTestCase): def test_init_with_training_arguments(self): dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train") args = TrainingArguments(output_dir=self.tmp_dir, report_to="none") SFTTrainer(model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", args=args, train_dataset=dataset) @pytest.mark.parametrize( "model_id", [ "trl-internal-testing/tiny-Cohere2ForCausalLM", pytest.param( "trl-internal-testing/tiny-Glm4MoeForCausalLM", marks=pytest.mark.skipif( Version(transformers.__version__) < Version("5.0.0"), reason="GLM4 tokenizer requires transformers>=5.0.0", ), ), "trl-internal-testing/tiny-GptOssForCausalLM", "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", "trl-internal-testing/tiny-Qwen3MoeForCausalLM", ], ) def test_train(self, model_id): # Get the dataset dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train") # Initialize the trainer training_args = SFTConfig(output_dir=self.tmp_dir, report_to="none") trainer = SFTTrainer(model=model_id, args=training_args, train_dataset=dataset) # Save the initial parameters to compare them later previous_trainable_params = {n: param.clone() for n, param in 
trainer.model.named_parameters()} # Train the model trainer.train() # Check that the training loss is not None assert trainer.state.log_history[-1]["train_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" # Special case for harmony def test_train_gpt_oss(self): # Get the dataset dataset = load_dataset("trl-internal-testing/harmony", "language_modeling", split="train") # Initialize the trainer training_args = SFTConfig(output_dir=self.tmp_dir, report_to="none") trainer = SFTTrainer( model="trl-internal-testing/tiny-GptOssForCausalLM", args=training_args, train_dataset=dataset ) # Save the initial parameters to compare them later previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} # Train the model trainer.train() # Check that the training loss is not None assert trainer.state.log_history[-1]["train_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" def test_train_model(self): # Instantiate the model model = AutoModelForCausalLM.from_pretrained( "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", dtype="float32", ) # Get the dataset dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train") # Initialize the trainer training_args = SFTConfig(output_dir=self.tmp_dir, report_to="none") trainer = SFTTrainer(model=model, args=training_args, train_dataset=dataset) # Save the initial parameters to compare them later previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} # Train the model trainer.train() # Check that the training loss is not None assert trainer.state.log_history[-1]["train_loss"] is not None # Check 
the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" def test_train_dft_loss(self): # Get the dataset dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling") # Initialize the trainer training_args = SFTConfig( output_dir=self.tmp_dir, loss_type="dft", learning_rate=0.1, # use higher lr because gradients are tiny and default lr can stall updates report_to="none", eval_strategy="steps", eval_steps=3, ) trainer = SFTTrainer( model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", args=training_args, train_dataset=dataset["train"], eval_dataset=dataset["test"], ) # Save the initial parameters to compare them later previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} # Train the model trainer.train() # Check that the training loss is not None assert trainer.state.log_history[-1]["train_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" def test_train_moe_model_with_aux_loss(self): # Get the dataset dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train") # Initialize the trainer training_args = SFTConfig( output_dir=self.tmp_dir, report_to="none", model_init_kwargs={"output_router_logits": True}, ) trainer = SFTTrainer( model="trl-internal-testing/tiny-Qwen3MoeForCausalLM", args=training_args, train_dataset=dataset ) # Save the initial parameters to compare them later previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} # Train the model trainer.train() # Check that the training loss and aux loss are not None assert trainer.state.log_history[-1]["train_loss"] is not None assert 
trainer.state.log_history[-1]["aux_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" def test_train_with_formatting_func(self): # Dummy formatting function def formatting_prompts_func(example): chosen, rejected = example["chosen"], example["rejected"] return f"### Chosen: {chosen}\n### Rejected: {rejected}" # Get the dataset dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train") # Initialize the trainer training_args = SFTConfig(output_dir=self.tmp_dir, report_to="none") trainer = SFTTrainer( model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", args=training_args, train_dataset=dataset, formatting_func=formatting_prompts_func, ) # Save the initial parameters to compare them later previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} # Train the model trainer.train() # Check that the training loss is not None assert trainer.state.log_history[-1]["train_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" def test_train_model_dtype(self): # Get the dataset dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train") # Initialize the trainer training_args = SFTConfig( output_dir=self.tmp_dir, model_init_kwargs={"dtype": torch.float16}, learning_rate=0.1, # use higher lr because gradients are tiny and default lr can stall updates report_to="none", ) trainer = SFTTrainer( model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", args=training_args, train_dataset=dataset ) # Save the initial parameters to compare them later previous_trainable_params = {n: param.clone() for n, param in 
trainer.model.named_parameters()} # Train the model trainer.train() # Check that the training loss is not None assert trainer.state.log_history[-1]["train_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): # For some reasonn model.layers.0.input_layernorm.weight doesn't change in GitHub Actions but does # locally. We ignore this parameter for now if "layernorm" in n: continue new_param = trainer.model.get_parameter(n) # Check the torch dtype assert new_param.dtype == torch.float16 assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" @require_peft def test_train_dense_with_peft_config_lora(self): # Get the base model parameter names model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" model = AutoModelForCausalLM.from_pretrained(model_id, dtype="float32") base_param_names = [f"base_model.model.{n}" for n, _ in model.named_parameters()] # Get the dataset dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train") # Initialize the trainer training_args = SFTConfig(output_dir=self.tmp_dir, report_to="none") trainer = SFTTrainer( model=model_id, args=training_args, train_dataset=dataset, peft_config=LoraConfig(), ) # Save the initial parameters to compare them later previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} # Train the model trainer.train() # Check that the training loss is not None assert trainer.state.log_history[-1]["train_loss"] is not None # Check the peft params have changed and the base model params have not changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) if n in base_param_names: # We expect the base model parameters to be the same torch.testing.assert_close(param, new_param), f"Parameter {n} has changed" elif "base_layer" not in n: # We expect the peft parameters to be different (except for the base layer) assert not torch.allclose(param, 
new_param), f"Parameter {n} has not changed" @pytest.mark.parametrize( "peft_type", [ "prompt_tuning", "prefix_tuning", "prompt_encoder", ], ) @require_peft def test_train_with_peft_config_prompt_tuning(self, peft_type): # Get the base model parameter names model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" model = AutoModelForCausalLM.from_pretrained(model_id, dtype="float32") base_param_names = [f"base_model.{n}" for n, _ in model.named_parameters()] # Get the dataset dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train") # Initialize the trainer, p-tuning doesn't support gradient checkpointing training_args = SFTConfig(bf16=False, output_dir=self.tmp_dir, report_to="none", gradient_checkpointing=False) if peft_type == "prompt_tuning": peft_config = PromptTuningConfig( task_type=TaskType.CAUSAL_LM, num_virtual_tokens=4, tokenizer_name_or_path="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", ) elif peft_type == "prefix_tuning": if parse_version(peft.__version__) <= Version("0.17.1"): pytest.xfail( "Prefix tuning with device_map='auto' is broken in peft 0.17.1 and below. 
See " "https://github.com/huggingface/peft/issues/2821" ) peft_config = PrefixTuningConfig( task_type=TaskType.CAUSAL_LM, num_virtual_tokens=4, ) elif peft_type == "prompt_encoder": peft_config = PromptEncoderConfig( task_type=TaskType.CAUSAL_LM, num_virtual_tokens=4, encoder_hidden_size=model.config.hidden_size, # This will be overwritten below ) trainer = SFTTrainer( model=model_id, args=training_args, train_dataset=dataset, peft_config=peft_config, ) # Save the initial parameters to compare them later previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} # Train the model trainer.train() # Check that the training loss is not None assert trainer.state.log_history[-1]["train_loss"] is not None # Check the peft params have changed and the base model params have not changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) if n in base_param_names: # We expect the base model parameters to be the same torch.testing.assert_close(param, new_param), f"Parameter {n} has changed" else: # We expect the peft parameters to be different assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" @require_peft def test_train_moe_with_peft_config(self): # Get the base model parameter names model_id = "trl-internal-testing/tiny-GptOssForCausalLM" model = AutoModelForCausalLM.from_pretrained(model_id, dtype="float32") base_param_names = [f"base_model.model.{n}" for n, _ in model.named_parameters()] # Get the dataset dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train") # Initialize the trainer training_args = SFTConfig(output_dir=self.tmp_dir, report_to="none") trainer = SFTTrainer( model=model_id, args=training_args, train_dataset=dataset, peft_config=LoraConfig(target_parameters=["mlp.experts.down_proj", "mlp.experts.gate_up_proj"]), ) # Save the initial parameters to compare them later previous_trainable_params = {n: param.clone() 
for n, param in trainer.model.named_parameters()} # Train the model trainer.train() # Check that the training loss is not None assert trainer.state.log_history[-1]["train_loss"] is not None # Check the peft params have changed and the base model params have not changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) if n in base_param_names: # We expect the base model parameters to be the same torch.testing.assert_close(param, new_param), f"Parameter {n} has changed" elif "base_layer" not in n: # We expect the peft parameters to be different (except for the base layer) assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" @require_peft def test_train_peft_model(self): # Get the base model model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" model = AutoModelForCausalLM.from_pretrained(model_id, dtype="float32") # Get the base model parameter names base_param_names = [f"base_model.model.{n}" for n, _ in model.named_parameters()] # Turn the model into a peft model lora_config = LoraConfig() model = get_peft_model(model, lora_config) # Get the dataset dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train") # Initialize the trainer training_args = SFTConfig(output_dir=self.tmp_dir, report_to="none") trainer = SFTTrainer(model=model, args=training_args, train_dataset=dataset) # Save the initial parameters to compare them later previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} # Train the model trainer.train() # Check that the training loss is not None assert trainer.state.log_history[-1]["train_loss"] is not None # Check the peft params have changed and the base model params have not changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) if n in base_param_names: # We expect the base model parameters to be the same torch.testing.assert_close(param, new_param), 
f"Parameter {n} has changed" elif "base_layer" not in n: # We expect the peft parameters to be different (except for the base layer) assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" # In practice, this test is the same as `test_train_dense_with_peft_config_lora`, since gradient checkpointing is # enabled by default in `SFTTrainer`. We keep it as a regression guard: if the default ever changes, we still # explicitly test PEFT + gradient checkpointing, which has caused issues in the past. @require_peft def test_train_with_peft_config_and_gradient_checkpointing(self): # Get the base model parameter names model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" model = AutoModelForCausalLM.from_pretrained(model_id, dtype="float32") base_param_names = [f"base_model.model.{n}" for n, _ in model.named_parameters()] # Get the dataset dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train") # Initialize the trainer training_args = SFTConfig(output_dir=self.tmp_dir, gradient_checkpointing=True, report_to="none") trainer = SFTTrainer( model=model_id, args=training_args, train_dataset=dataset, peft_config=LoraConfig(), ) # Save the initial parameters to compare them later previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} # Train the model trainer.train() # Check that the training loss is not None assert trainer.state.log_history[-1]["train_loss"] is not None # Check the peft params have changed and the base model params have not changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) if n in base_param_names: # We expect the base model parameters to be the same torch.testing.assert_close(param, new_param), f"Parameter {n} has changed" elif "base_layer" not in n: # We expect the peft parameters to be different (except for the base layer) assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" 
@pytest.mark.parametrize("use_reentrant", [True, False]) @require_peft def test_train_with_peft_config_and_gradient_checkpointing_reentrant(self, use_reentrant): # Get the base model parameter names model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" model = AutoModelForCausalLM.from_pretrained(model_id, dtype="float32") base_param_names = [f"base_model.model.{n}" for n, _ in model.named_parameters()] # Get the dataset dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train") # Initialize the trainer training_args = SFTConfig( output_dir=self.tmp_dir, gradient_checkpointing=True, gradient_checkpointing_kwargs={"use_reentrant": use_reentrant}, report_to="none", ) trainer = SFTTrainer( model=model_id, args=training_args, train_dataset=dataset, peft_config=LoraConfig(), ) # Save the initial parameters to compare them later previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} # Train the model trainer.train() # Check that the training loss is not None assert trainer.state.log_history[-1]["train_loss"] is not None # Check the peft params have changed and the base model params have not changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) if n in base_param_names: # We expect the base model parameters to be the same torch.testing.assert_close(param, new_param), f"Parameter {n} has changed" elif "base_layer" not in n: # We expect the peft parameters to be different (except for the base layer) assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" @require_liger_kernel def test_train_with_liger(self): # Get the dataset dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train") # Initialize the trainer training_args = SFTConfig(output_dir=self.tmp_dir, use_liger_kernel=True, report_to="none") trainer = SFTTrainer( model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", 
args=training_args, train_dataset=dataset
        )

        # Save the initial parameters to compare them later
        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

        # Train the model
        trainer.train()

        # Check that the training loss is not None
        assert trainer.state.log_history[-1]["train_loss"] is not None

        # Check the params have changed
        for n, param in previous_trainable_params.items():
            new_param = trainer.model.get_parameter(n)
            assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"

    @require_torch_accelerator
    @require_liger_kernel
    def test_compute_loss_skip_logits_on_eval_without_metrics_with_liger(self):
        """Check that `compute_loss` forwards `skip_logits=True` in eval mode when no metrics are computed.

        `transformers.Trainer.compute_loss` is patched so the test can inspect the `inputs` dict it receives.
        """
        dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train[:1]")
        # The trainer is built with `use_liger_kernel=False`; the flag is only switched on afterwards (below).
        # NOTE(review): presumably this avoids applying the Liger kernels at construction time - confirm.
        training_args = SFTConfig(
            output_dir=self.tmp_dir,
            use_liger_kernel=False,
            report_to="none",
            max_length=8,
            bf16=False,
        )
        trainer = SFTTrainer(
            model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
            args=training_args,
            train_dataset=dataset,
            compute_metrics=None,
        )
        trainer.args.use_liger_kernel = True
        trainer.model.eval()

        # Filled in by the mock below with whatever `skip_logits` value reaches the parent `compute_loss`
        captured = {}

        def mock_super_compute_loss(model, inputs, return_outputs=False, num_items_in_batch=None):
            # Record the flag, then return a (loss, outputs) pair built from dummy values
            captured["skip_logits"] = inputs.get("skip_logits")
            dummy_loss = torch.tensor(1.0, requires_grad=True)
            dummy_outputs = MagicMock()
            dummy_outputs.token_accuracy = None
            dummy_outputs.logits = torch.randn(1, 5, trainer.model.config.vocab_size)
            return (dummy_loss, dummy_outputs)

        inputs = {
            "input_ids": torch.tensor([[1, 2, 3, 4, 5]]),
            "labels": torch.tensor([[1, 2, 3, 4, 5]]),
            "attention_mask": torch.tensor([[1, 1, 1, 1, 1]]),
        }

        with patch("transformers.Trainer.compute_loss", side_effect=mock_super_compute_loss):
            trainer.compute_loss(trainer.model, inputs)

        assert captured["skip_logits"] is True

    @require_torch_accelerator
    @require_liger_kernel
    def test_predict_does_not_skip_logits_with_liger(self):
        """Check that `predict` reaches the parent `compute_loss` with `skip_logits` set to `False`."""
        dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train[:1]")
        # Same setup as the eval test: build without Liger, then switch the flag on afterwards
        training_args = SFTConfig(
            output_dir=self.tmp_dir,
            use_liger_kernel=False,
            report_to="none",
            max_length=8,
            bf16=False,
        )
        trainer = SFTTrainer(
            model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
            args=training_args,
            train_dataset=dataset,
            compute_metrics=None,
        )
        trainer.args.use_liger_kernel = True
        trainer.model.eval()

        # Filled in by the mock below with whatever `skip_logits` value reaches the parent `compute_loss`
        captured = {}

        def mock_super_compute_loss(model, inputs, return_outputs=False, num_items_in_batch=None):
            captured["skip_logits"] = inputs.get("skip_logits")
            dummy_loss = torch.tensor(1.0, requires_grad=True)
            # Here the outputs are a plain (loss, logits) tuple rather than a MagicMock
            dummy_outputs = (dummy_loss, torch.randn(1, 5, trainer.model.config.vocab_size))
            return (dummy_loss, dummy_outputs)

        with patch("transformers.Trainer.compute_loss", side_effect=mock_super_compute_loss):
            trainer.predict(trainer.train_dataset)

        assert captured["skip_logits"] is False

    def test_train_with_non_chatml_conversational_data(self):
        """Train on conversations stored under `conversations` with `from`/`value` keys instead of ChatML."""
        # Get the dataset
        dataset = load_dataset("trl-internal-testing/zen", "conversational_language_modeling", split="train")

        # Rename role/content to from/value to ensure SFT works with non-chatML conversational data
        def rename_fields(example: dict) -> dict:
            return {"conversations": [{"from": m["role"], "value": m["content"]} for m in example["messages"]]}

        dataset = dataset.map(rename_fields, remove_columns="messages")

        # Initialize the trainer
        training_args = SFTConfig(output_dir=self.tmp_dir, report_to="none")
        trainer = SFTTrainer(
            model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", args=training_args, train_dataset=dataset
        )

        # Save the initial parameters to compare them later
        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

        # Train the model
        trainer.train()

        # Check that the training loss is not None
        assert trainer.state.log_history[-1]["train_loss"] is not None

        # Check the params have changed
        for n, param in previous_trainable_params.items():
            new_param = trainer.model.get_parameter(n)
            assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"

    def
test_train_with_pretokenized_data(self): # Get the dataset model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" tokenizer = AutoTokenizer.from_pretrained(model_id) dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train") def tokenize_example(example): return tokenizer(example["text"]) # Apply tokenization tokenized_dataset = dataset.map(tokenize_example, remove_columns=["text"]) # Initialize the trainer training_args = SFTConfig(output_dir=self.tmp_dir, report_to="none") trainer = SFTTrainer(model=model_id, args=training_args, train_dataset=tokenized_dataset) # Save the initial parameters to compare them later previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} # Train the model trainer.train() # Check that the training loss is not None assert trainer.state.log_history[-1]["train_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" def test_skip_prepare_dataset_passes_truncation_to_text_collator(self): dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train[:2]") with pytest.warns(FutureWarning, match="keep_end.*deprecated"): training_args = SFTConfig( output_dir=self.tmp_dir, max_length=16, truncation_mode="keep_end", dataset_kwargs={"skip_prepare_dataset": True}, report_to="none", ) trainer = SFTTrainer( model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", args=training_args, train_dataset=dataset ) assert isinstance(trainer.data_collator, DataCollatorForLanguageModeling) assert trainer.data_collator.max_length == 16 assert trainer.data_collator.truncation_mode == "keep_end" def test_padding_free_without_packing_and_max_length_raises(self): dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train[:2]") training_args = SFTConfig( 
output_dir=self.tmp_dir, max_length=16, padding_free=True, report_to="none", ) with pytest.raises(ValueError, match="`max_length` is not enforced"): SFTTrainer( model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", args=training_args, train_dataset=dataset ) def test_train_with_iterable_dataset(self): # Get the dataset dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train", streaming=True) # Initialize the trainer training_args = SFTConfig(output_dir=self.tmp_dir, max_steps=3, report_to="none") trainer = SFTTrainer( model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", args=training_args, train_dataset=dataset ) # Save the initial parameters to compare them later previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} # Train the model trainer.train() # Check that the training loss is not None assert trainer.state.log_history[-1]["train_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" @require_kernels @require_ampere_or_newer # Flash attention 2 requires Ampere or newer GPUs def test_train_padding_free(self): # Get the dataset dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train") # Initialize the trainer training_args = SFTConfig( output_dir=self.tmp_dir, padding_free=True, model_init_kwargs={"attn_implementation": "kernels-community/flash-attn2"}, bf16=True, # flash_attention_2 only supports bf16 and fp16 report_to="none", ) trainer = SFTTrainer( model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", args=training_args, train_dataset=dataset ) # Save the initial parameters to compare them later previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} # Train the model trainer.train() # Check that the training loss is not None 
assert trainer.state.log_history[-1]["train_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" @pytest.mark.parametrize("packing_strategy", ["bfd", "wrapped"]) @ignore_warnings(message="You are using packing, but the attention implementation is not.*", category=UserWarning) @ignore_warnings(message="Padding-free training is enabled, but the attention.*", category=UserWarning) def test_train_packing(self, packing_strategy): # Get the dataset dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train") # Initialize the trainer training_args = SFTConfig( output_dir=self.tmp_dir, packing=True, packing_strategy=packing_strategy, max_length=10, report_to="none" ) trainer = SFTTrainer( model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", args=training_args, train_dataset=dataset ) # Save the initial parameters to compare them later previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} # Train the model trainer.train() # Check that the training loss is not None assert trainer.state.log_history[-1]["train_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" @ignore_warnings(message="You are using packing, but the attention implementation is not.*", category=UserWarning) @ignore_warnings(message="Padding-free training is enabled, but the attention.*", category=UserWarning) def test_eval_packing(self): # Get the dataset dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling") # Initialize the trainer training_args = SFTConfig( output_dir=self.tmp_dir, packing=True, max_length=64, report_to="none", ) trainer = SFTTrainer( 
model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
            args=training_args,
            train_dataset=dataset["train"],
            eval_dataset=dataset["test"],
        )

        # Check that packing did not drop any sequence: the per-row entries of `seq_lengths` are counted
        # NOTE(review): each `seq_lengths` entry presumably lists the lengths of the sequences packed into one
        # row - confirm against the packing implementation
        num_train_seqs = sum(len(x) for x in trainer.train_dataset["seq_lengths"])
        num_eval_seqs = sum(len(x) for x in trainer.eval_dataset["seq_lengths"])
        assert num_train_seqs == 17  # we should still have 17 seqs
        assert num_eval_seqs == 2  # we should still have 2 seqs

        # Check that no packed row exceeds the max length (64)
        assert all(sum(x) <= 64 for x in trainer.train_dataset["seq_lengths"])
        assert all(sum(x) <= 64 for x in trainer.eval_dataset["seq_lengths"])

        # Check the number of packed rows in train and eval datasets
        assert len(trainer.train_dataset["input_ids"]) == 3  # w/ this dataset, we end up with 3 packed rows
        assert len(trainer.eval_dataset["input_ids"]) == 1  # w/ this dataset, we end up with 1 packed row

    @ignore_warnings(message="You are using packing, but the attention implementation is not.*", category=UserWarning)
    @ignore_warnings(message="Padding-free training is enabled, but the attention.*", category=UserWarning)
    def test_only_train_packing(self):
        """Pack only the training split: with `eval_packing=False` the eval set must not gain `seq_lengths`."""
        # Get the dataset
        dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling")

        # Initialize the trainer
        training_args = SFTConfig(
            output_dir=self.tmp_dir,
            packing=True,
            eval_packing=False,
            max_length=64,
            report_to="none",
        )
        trainer = SFTTrainer(
            model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
            args=training_args,
            train_dataset=dataset["train"],
            eval_dataset=dataset["test"],
        )

        # Check the number of sequences in train dataset
        num_train_seqs = sum(len(x) for x in trainer.train_dataset["seq_lengths"])
        assert num_train_seqs == 17  # we should still have 17 seqs

        # We expect eval dataset not having "seq_lengths" as eval_packing is False
        assert "seq_lengths" not in trainer.eval_dataset

        # Check that no packed row exceeds the max length (64)
        assert all(sum(x) <= 64 for x in
trainer.train_dataset["seq_lengths"]) # Check the number of sequences in train and eval datasets assert len(trainer.train_dataset["input_ids"]) == 3 # w/ this dataset, we end up with 46 seqs assert len(trainer.eval_dataset["input_ids"]) == 2 # w/ this dataset, we end up with 6 seqs def test_train_with_chat_template_kwargs(self): # Get the dataset dataset = load_dataset("trl-internal-testing/zen", "conversational_language_modeling", split="train") # Initialize the trainer training_args = SFTConfig(output_dir=self.tmp_dir, report_to="none") tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5") # The following template is a simplified version of the Qwen chat template, where an additional argument # `role_capital` is used to control the capitalization of roles. tokenizer.chat_template = '{%- if messages[0]["role"] == "system" -%} {{ "<|im_start|>" + ("SYSTEM" if role_capital else "system") + "\\n" + messages[0]["content"] + "<|im_end|>\\n" }}{%- else -%} {{ "<|im_start|>" + ("SYSTEM" if role_capital else "system") + "\\nYou are Qwen, created by Alibaba Cloud. 
You are a helpful assistant.<|im_end|>\\n" }}{%- endif -%}{%- for message in messages -%} {%- if (message.role == "user") or (message.role == "system" and not loop.first) or (message.role == "assistant" and not message.tool_calls) -%} {{ "<|im_start|>" + (message.role.upper() if role_capital else message.role) + "\\n" + message.content + "<|im_end|>\\n" }} {%- elif message.role == "assistant" -%} {{ "<|im_start|>" + ("ASSISTANT" if role_capital else "assistant") }} {%- if message.content -%} {{ "\\n" + message.content }} {%- endif -%} {{ "<|im_end|>\\n" }} {%- elif message.role == "tool" -%} {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != "tool") -%} {{ "<|im_start|>" + ("USER" if role_capital else "user") }} {%- endif -%} {{ "\\n\\n" + message.content + "\\n" }} {%- if loop.last or (messages[loop.index0 + 1].role != "tool") -%} {{ "<|im_end|>\\n" }} {%- endif -%} {%- endif -%}{%- endfor -%}{%- if add_generation_prompt -%} {{ "<|im_start|>" + ("ASSISTANT" if role_capital else "assistant") + "\\n" }}{%- endif -%}' dataset = dataset.add_column( "chat_template_kwargs", [{"role_capital": bool(i % 2)} for i in range(len(dataset))] ) assert "chat_template_kwargs" in dataset.features trainer = SFTTrainer( model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", args=training_args, train_dataset=dataset, processing_class=tokenizer, ) # Assert trainer uses the same chat template as tokenizer assert trainer.processing_class.chat_template == tokenizer.chat_template # Assert chat_template is applied for i in range(2): role = "SYSTEM" if i else "system" system_prompt = ( f"<|im_start|>{role}\nYou are Qwen, created by Alibaba Cloud. 
You are a helpful assistant.<|im_end|>" ) system_prompt_ids = trainer.processing_class(system_prompt)["input_ids"] assert trainer.train_dataset[i]["input_ids"][: len(system_prompt_ids)] == system_prompt_ids # Save the initial parameters to compare them later previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} # Train the model trainer.train() # Check that the training loss is not None assert trainer.state.log_history[-1]["train_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" def test_train_assistant_only(self): # Get the dataset dataset = load_dataset("trl-internal-testing/zen", "conversational_language_modeling", split="train") # Initialize the trainer training_args = SFTConfig(output_dir=self.tmp_dir, assistant_only_loss=True, report_to="none") trainer = SFTTrainer( model="trl-internal-testing/tiny-Qwen3ForCausalLM", args=training_args, train_dataset=dataset ) # Save the initial parameters to compare them later previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} # Train the model trainer.train() # Check that the training loss is not None assert trainer.state.log_history[-1]["train_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" def test_train_completion_only(self): # Get the dataset dataset = load_dataset("trl-internal-testing/zen", "conversational_prompt_completion", split="train") # Initialize the trainer training_args = SFTConfig(output_dir=self.tmp_dir, completion_only_loss=True, report_to="none") trainer = SFTTrainer( model="trl-internal-testing/tiny-Qwen3ForCausalLM", args=training_args, train_dataset=dataset ) # Save the 
initial parameters to compare them later previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} # Train the model trainer.train() # Check that the training loss is not None assert trainer.state.log_history[-1]["train_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" def test_train_completion_only_harmony(self): # Get the dataset dataset = load_dataset("trl-internal-testing/harmony", "prompt_completion", split="train") # Initialize the trainer training_args = SFTConfig(output_dir=self.tmp_dir, completion_only_loss=True, report_to="none") trainer = SFTTrainer( model="trl-internal-testing/tiny-GptOssForCausalLM", args=training_args, train_dataset=dataset ) # Save the initial parameters to compare them later previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} # Train the model trainer.train() # Check that the training loss is not None assert trainer.state.log_history[-1]["train_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" def test_train_assistant_only_and_completion_only(self): # Get the dataset dataset = load_dataset("trl-internal-testing/zen", "conversational_prompt_completion", split="train") # To test this case, we need to add user messages in the completion (they'll be masked in the loss) def add_to_completion(example): example["completion"].append(example["prompt"][0]) example["completion"].append(example["completion"][0]) return example dataset = dataset.map(add_to_completion) # Initialize the trainer training_args = SFTConfig( output_dir=self.tmp_dir, assistant_only_loss=True, completion_only_loss=True, report_to="none" ) trainer = 
SFTTrainer( model="trl-internal-testing/tiny-Qwen3ForCausalLM", args=training_args, train_dataset=dataset ) # Save the initial parameters to compare them later previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} # Train the model trainer.train() # Check that the training loss is not None assert trainer.state.log_history[-1]["train_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" def test_train_assistant_only_iterable_dataset(self): # Get the dataset dataset = load_dataset( "trl-internal-testing/zen", "conversational_language_modeling", split="train", streaming=True ) # Initialize the trainer training_args = SFTConfig(output_dir=self.tmp_dir, assistant_only_loss=True, max_steps=3, report_to="none") trainer = SFTTrainer( model="trl-internal-testing/tiny-Qwen3ForCausalLM", args=training_args, train_dataset=dataset ) # Save the initial parameters to compare them later previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} # Train the model trainer.train() # Check that the training loss is not None assert trainer.state.log_history[-1]["train_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" def test_train_with_set_chat_template_from_model(self): # Get the dataset dataset = load_dataset("trl-internal-testing/zen", "conversational_language_modeling", split="train") # Initialize the trainer training_args = SFTConfig(output_dir=self.tmp_dir, chat_template_path="Qwen/Qwen3-4B", report_to="none") # trl-internal-testing/tiny-GPTNeoXForCausalLM doesn't have a chat template set by default trainer = SFTTrainer( 
model="trl-internal-testing/tiny-GPTNeoXForCausalLM", args=training_args, train_dataset=dataset ) # Save the initial parameters to compare them later previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} # Train the model trainer.train() # Check that the training loss is not None assert trainer.state.log_history[-1]["train_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" def test_train_with_set_chat_template_from_path(self, lazy_shared_datadir): # Get the dataset dataset = load_dataset("trl-internal-testing/zen", "conversational_language_modeling", split="train") # Initialize the trainer training_args = SFTConfig( output_dir=self.tmp_dir, chat_template_path=str(lazy_shared_datadir / "template.jinja"), report_to="none", ) # trl-internal-testing/tiny-GPTNeoXForCausalLM doesn't have a chat template set by default trainer = SFTTrainer( model="trl-internal-testing/tiny-GPTNeoXForCausalLM", args=training_args, train_dataset=dataset ) # Save the initial parameters to compare them later previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} # Train the model trainer.train() # Check that the training loss is not None assert trainer.state.log_history[-1]["train_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" # Check that the template saved in the output directory is the same as the one used for training template_path = pathlib.Path(self.tmp_dir) / "checkpoint-9" / "chat_template.jinja" assert template_path.exists(), f"Chat template not found at {template_path}" with open(template_path) as f: template_content = f.read() with 
open(training_args.chat_template_path) as f: original_template_content = f.read() assert template_content == original_template_content, "Chat template content does not match the original" def test_train_toolcall_data(self): # Get the dataset dataset = load_dataset("trl-internal-testing/toolcall", "language_modeling", split="train") # Initialize the trainer training_args = SFTConfig(output_dir=self.tmp_dir, report_to="none") trainer = SFTTrainer( model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", args=training_args, train_dataset=dataset ) # Save the initial parameters to compare them later previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} # Train the model trainer.train() # Check that the training loss is not None assert trainer.state.log_history[-1]["train_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" def test_train_toolcall_data_as_json(self): # Tabular backends (Arrow/Parquet) can insert `None` for missing keys in nested structures. # If `tools` is stored as a list of dicts and examples use different dict schemas, nulls may # be introduced and break tool processing. This test ensures we also support `tools` provided # as a list of dicts. 
# Get the dataset dataset = load_dataset("trl-internal-testing/toolcall", "language_modeling", split="train") def convert_to_json(example): return {"tools": json.loads(example["tools"])} dataset = dataset.map(convert_to_json) # Initialize the trainer training_args = SFTConfig(output_dir=self.tmp_dir, report_to="none") trainer = SFTTrainer( model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", args=training_args, train_dataset=dataset ) # Save the initial parameters to compare them later previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} # Train the model trainer.train() # Check that the training loss is not None assert trainer.state.log_history[-1]["train_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" def test_train_with_eval(self): # Get the dataset dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling") # Initialize the trainer training_args = SFTConfig(output_dir=self.tmp_dir, eval_strategy="steps", eval_steps=3, report_to="none") trainer = SFTTrainer( model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", args=training_args, train_dataset=dataset["train"], eval_dataset=dataset["test"], ) # Train the model trainer.train() # Check that the eval loss is not None assert trainer.state.log_history[0]["eval_loss"] is not None def test_train_with_multiple_eval_dataset(self): # Get the dataset dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling") # Initialize the trainer training_args = SFTConfig(output_dir=self.tmp_dir, eval_strategy="steps", eval_steps=3, report_to="none") trainer = SFTTrainer( model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", args=training_args, train_dataset=dataset["train"], eval_dataset={"data1": dataset["test"], "data2": dataset["test"]}, ) # Train the model 
trainer.train() # Check that the eval losses are not None assert trainer.state.log_history[-3]["eval_data1_loss"] is not None assert trainer.state.log_history[-2]["eval_data2_loss"] is not None def test_train_with_compute_metrics(self): # Get the dataset dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling") def dummy_compute_metrics(eval_pred): return {"my_metric": 0.123} # Initialize the trainer training_args = SFTConfig( output_dir=self.tmp_dir, eval_strategy="steps", eval_steps=3, report_to="none", ) trainer = SFTTrainer( model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", args=training_args, train_dataset=dataset["train"], eval_dataset=dataset["test"], compute_metrics=dummy_compute_metrics, ) # Train the model trainer.train() # Check that the custom metric is logged assert trainer.state.log_history[-2]["eval_my_metric"] == 0.123 # In practice, this test is the same as `test_train`, since gradient checkpointing is enabled by default in # `SFTTrainer`. We keep it as a regression guard: if the default ever changes, we still explicitly test gradient # checkpointing, which has caused issues in the past. 
# NOTE(review): the `self`-taking tests below read `self.tmp_dir`, so they are presumably methods of a
# `TrlTestCase` subclass whose `class` header lies before this chunk. They are written at the column of the
# chunk's first line; re-indent them under that class when splicing -- TODO confirm.
def test_train_with_gradient_checkpointing(self):
    """Train with gradient checkpointing enabled and check that every parameter is updated."""
    # Get the dataset
    dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")

    # Initialize the trainer
    training_args = SFTConfig(output_dir=self.tmp_dir, gradient_checkpointing=True, report_to="none")
    trainer = SFTTrainer(
        model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", args=training_args, train_dataset=dataset
    )

    # Save the initial parameters to compare them later
    previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

    # Train the model
    trainer.train()

    # Check that the training loss is not None
    assert trainer.state.log_history[-1]["train_loss"] is not None

    # Check the params have changed
    for n, param in previous_trainable_params.items():
        new_param = trainer.model.get_parameter(n)
        assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"


@pytest.mark.parametrize("use_reentrant", [True, False])
def test_train_with_gradient_checkpointing_reentrant(self, use_reentrant):
    """Train with gradient checkpointing for both values of `use_reentrant` and check parameters are updated."""
    # Get the dataset
    dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")

    # Initialize the trainer
    training_args = SFTConfig(
        output_dir=self.tmp_dir,
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={"use_reentrant": use_reentrant},
        report_to="none",
    )
    trainer = SFTTrainer(
        model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", args=training_args, train_dataset=dataset
    )

    # Save the initial parameters to compare them later
    previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

    # Train the model
    trainer.train()

    # Check that the training loss is not None
    assert trainer.state.log_history[-1]["train_loss"] is not None

    # Check the params have changed
    for n, param in previous_trainable_params.items():
        new_param = trainer.model.get_parameter(n)
        assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"


def test_tag_added(self):
    """Check that the "sft" and "trl" tags are present in the trainer model's `model_tags`."""
    # Get the dataset
    dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")

    # Initialize the trainer
    trainer = SFTTrainer(
        model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
        train_dataset=dataset,
    )

    for tag in ["sft", "trl"]:
        assert tag in trainer.model.model_tags


@require_peft
def test_tag_added_peft(self):
    """Check that the "sft" and "trl" tags are present in `model_tags` when a `LoraConfig` is passed."""
    # Get the dataset
    dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")

    # Initialize the trainer
    trainer = SFTTrainer(
        model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
        train_dataset=dataset,
        peft_config=LoraConfig(),
    )

    for tag in ["sft", "trl"]:
        assert tag in trainer.model.model_tags


@pytest.mark.parametrize(
    "model_id",
    [
        "trl-internal-testing/tiny-Gemma3ForConditionalGeneration",
        pytest.param(
            "trl-internal-testing/tiny-Gemma4ForConditionalGeneration",
            marks=pytest.mark.skipif(
                Version(transformers.__version__) < Version("5.5.0"),
                reason="Gemma4 models were introduced in transformers-5.5.0",
            ),
        ),
        # "trl-internal-testing/tiny-Idefics2ForConditionalGeneration",  high memory peak, skipped for now
        # "trl-internal-testing/tiny-Idefics3ForConditionalGeneration",  high memory peak, skipped for now
        "trl-internal-testing/tiny-LlavaForConditionalGeneration",
        "trl-internal-testing/tiny-LlavaNextForConditionalGeneration",
        "trl-internal-testing/tiny-Qwen2VLForConditionalGeneration",
        "trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
        # "trl-internal-testing/tiny-SmolVLMForConditionalGeneration",  seems not to support bf16 properly
        pytest.param(
            "trl-internal-testing/tiny-Qwen3VLForConditionalGeneration",
            marks=[
                pytest.mark.skipif(
                    Version(transformers.__version__) < Version("4.57.0"),
                    reason="Qwen3-VL series were introduced in transformers-4.57.0",
                ),
            ],
        ),
        pytest.param(
            "trl-internal-testing/tiny-Qwen3_5ForConditionalGeneration",
            marks=pytest.mark.skipif(
                Version(transformers.__version__) < Version("5.2.0"),
                reason="Qwen3.5 models were introduced in transformers-5.2.0",
            ),
        ),
    ],
)
@require_vision
def test_train_vlm(self, model_id):
    """Train each tiny VLM on an image dataset and check that parameters are updated.

    Parameters whose name contains a known "never updated" fragment for the given `model_id` are skipped.
    """
    # Get the dataset
    dataset = load_dataset("trl-internal-testing/zen-image", "conversational_language_modeling", split="train")

    # Initialize the trainer
    training_args = SFTConfig(
        output_dir=self.tmp_dir,
        max_length=None,  # for VLMs, truncating can remove image tokens, leading to errors
        report_to="none",
    )
    trainer = SFTTrainer(model=model_id, args=training_args, train_dataset=dataset)

    # Save the initial parameters to compare them later
    previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

    # Train the model
    trainer.train()

    # Check that the training loss is not None
    assert trainer.state.log_history[-1]["train_loss"] is not None

    # For some reason, these params are not updated. This is probably not related to TRL, but to
    # the model itself. We should investigate this further, but for now we just skip these params.
    # Llava and LlavaNext share the same set of name fragments.
    llava_not_updated = (
        "model.vision_tower.vision_model.post_layernorm",  # transformers < 5.6.0
        "vision_tower.vision_model.encoder.layers.1",  # transformers < 5.6.0
        "model.vision_tower.encoder.layers.1",  # transformers >= 5.6.0, see #5497
        "model.vision_tower.post_layernorm",  # transformers >= 5.6.0, see #5497
    )
    not_updated_by_model = {
        "trl-internal-testing/tiny-Gemma3ForConditionalGeneration": ("model.vision_tower.vision_model.head",),
        "trl-internal-testing/tiny-LlavaForConditionalGeneration": llava_not_updated,
        "trl-internal-testing/tiny-LlavaNextForConditionalGeneration": llava_not_updated,
        "trl-internal-testing/tiny-Qwen3VLForConditionalGeneration": ("model.visual.deepstack_merger_list",),
    }
    # Models without an entry have nothing to skip
    not_updated = not_updated_by_model.get(model_id, ())

    # Check the params have changed
    for n, param in previous_trainable_params.items():
        new_param = trainer.model.get_parameter(n)
        if any(fragment in n for fragment in not_updated):
            continue
        assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is not updated"


@pytest.mark.parametrize(
    "model_id",
    [
        "trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
    ],
)
@pytest.mark.xfail(
    parse_version(transformers.__version__) < parse_version("4.57.0"),
    reason="Mixing text-only and image+text examples is only supported in transformers >= 4.57.0",
    strict=False,
)
@require_vision
def test_train_vlm_multi_image(self, model_id):
    """Train a VLM on the multi-image prompt-completion dataset and check that every parameter is updated."""
    # Get the dataset
    dataset = load_dataset(
        "trl-internal-testing/zen-multi-image", "conversational_prompt_completion", split="train"
    )

    # Initialize the trainer
    training_args = SFTConfig(
        output_dir=self.tmp_dir,
        learning_rate=0.1,  # use higher lr because gradients are tiny and default lr can stall updates
        max_length=None,  # for VLMs, truncating can remove image tokens, leading to errors
        report_to="none",
    )
    trainer = SFTTrainer(
        model=model_id,
        args=training_args,
        train_dataset=dataset,
    )

    # Save the initial parameters to compare them later
    previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

    # Train the model
    trainer.train()

    # Check that the training loss is not None
    assert trainer.state.log_history[-1]["train_loss"] is not None

    # Check the params have changed
    for n, param in previous_trainable_params.items():
        new_param = trainer.model.get_parameter(n)
        assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is not updated"


@pytest.mark.parametrize(
    "model_id",
    [
        "trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
        # Special case for Gemma, as it uses token_type_ids, and we need to ensure they are properly in the
        # collator:
        "trl-internal-testing/tiny-Gemma3ForConditionalGeneration",
        pytest.param(
            "trl-internal-testing/tiny-Gemma4ForConditionalGeneration",
            marks=pytest.mark.skipif(
                Version(transformers.__version__) < Version("5.5.0"),
                reason="Gemma4 models were introduced in transformers-5.5.0",
            ),
        ),
    ],
)
@require_vision
def test_train_vlm_prompt_completion(self, model_id):
    """Train a VLM on the image prompt-completion dataset and check that every parameter is updated."""
    # Get the dataset
    dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_completion", split="train")

    # Initialize the trainer
    training_args = SFTConfig(
        output_dir=self.tmp_dir,
        learning_rate=0.1,  # use higher lr because gradients are tiny and default lr can stall updates
        max_length=None,  # for VLMs, truncating can remove image tokens, leading to errors
        report_to="none",
    )
    trainer = SFTTrainer(
        model=model_id,
        args=training_args,
        train_dataset=dataset,
    )

    # Save the initial parameters to compare them later
    previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

    # Train the model
    trainer.train()

    # Check that the training loss is not None
    assert trainer.state.log_history[-1]["train_loss"] is not None

    # Check the params have changed
    for n, param in previous_trainable_params.items():
        new_param = trainer.model.get_parameter(n)
        assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is not updated"


# Gemma 3n uses a timm encoder, making it difficult to create a smaller variant for testing.
# To ensure coverage, we run tests on the full model but mark them as slow to exclude from default runs.
@pytest.mark.slow
@require_vision
@pytest.mark.skip(reason="Model google/gemma-3n-E2B-it is gated and requires HF token")
def test_train_vlm_gemma_3n(self):
    """Train the full Gemma 3n model on an image dataset and check that non-audio parameters are updated."""
    # Get the dataset
    dataset = load_dataset("trl-internal-testing/zen-image", "conversational_language_modeling", split="train")

    # Initialize the trainer
    training_args = SFTConfig(
        output_dir=self.tmp_dir,
        learning_rate=0.1,  # use higher lr because gradients are tiny and default lr can stall updates
        max_length=None,  # for VLMs, truncating can remove image tokens, leading to errors
        per_device_train_batch_size=1,  # VLM training is memory intensive, reduce batch size to avoid OOM
        model_init_kwargs={"dtype": "bfloat16"},
        report_to="none",
    )
    trainer = SFTTrainer(model="google/gemma-3n-E2B-it", args=training_args, train_dataset=dataset)

    # Save the initial parameters to compare them later
    previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

    # Train the model
    trainer.train()

    # Check that the training loss is not None
    assert trainer.state.log_history[-1]["train_loss"] is not None

    # Check the params have changed
    for n, param in previous_trainable_params.items():
        new_param = trainer.model.get_parameter(n)
        if "model.audio_tower" in n or "model.embed_audio" in n:
            # The audio embedding parameters are not updated because this dataset contains no audio data
            continue
        assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is not updated"


@pytest.mark.parametrize(
    "model_id",
    [
        "trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
    ],
)
@pytest.mark.parametrize(
    "dataset_config",
    [
        "conversational_language_modeling",
        "conversational_prompt_completion",
        "standard_language_modeling",  # Regression test for #5334
        "standard_prompt_completion",
    ],
)
@require_vision
def test_train_vlm_text_only_data(self, model_id, dataset_config):
    """Train a VLM on text-only data: `model.visual*` parameters must stay unchanged, all others must change."""
    # Get the dataset
    dataset = load_dataset("trl-internal-testing/zen", dataset_config, split="train")

    # Initialize the trainer
    training_args = SFTConfig(output_dir=self.tmp_dir, report_to="none")
    trainer = SFTTrainer(
        model=model_id,
        args=training_args,
        train_dataset=dataset,
    )

    # Save the initial parameters to compare them later
    previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

    # Train the model
    trainer.train()

    # Check that the training loss is not None
    assert trainer.state.log_history[-1]["train_loss"] is not None

    # Check the params have changed
    for n, param in previous_trainable_params.items():
        new_param = trainer.model.get_parameter(n)
        if n.startswith("model.visual"):
            # The message must go through `msg=`; a trailing `, f"..."` only builds a discarded tuple
            torch.testing.assert_close(param, new_param, rtol=1e-12, atol=1e-12, msg=f"Param {n} is updated")
        else:
            assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is not updated"


@require_peft
def test_prompt_tuning(self):
    """Test that SFT works with Prompt Tuning (configured here through `PromptEncoderConfig`)."""
    dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")

    training_args = SFTConfig(output_dir=self.tmp_dir, report_to="none")
    trainer = SFTTrainer(
        model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
        args=training_args,
        train_dataset=dataset,
        peft_config=PromptEncoderConfig(task_type=TaskType.CAUSAL_LM, num_virtual_tokens=8),
    )

    # Save initial parameters to check they change during training
    previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

    trainer.train()

    # Check that training completed successfully
    assert trainer.state.log_history[-1]["train_loss"] is not None
    assert trainer.state.log_history[-1]["mean_token_accuracy"] is not None

    # Check the peft params have changed and the base model params have not changed
    for n, param in previous_trainable_params.items():
        new_param = trainer.model.get_parameter(n)
        if "base_model" in n:
            # We expect the base model parameters to be the same
            torch.testing.assert_close(param, new_param, msg=f"Parameter {n} has changed")
        elif "prompt_encoder" in n:
            # We expect the peft parameters to be different
            assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"
        else:
            raise ValueError(f"Unexpected parameter {n} in model: {trainer.model}")


@require_peft
@require_bitsandbytes
def test_peft_with_quantization(self):
    """Train LoRA adapters on a 4-bit quantized model: only parameters with "lora" in their name may change."""
    # Get the base model
    model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        dtype="float32",
        quantization_config=quantization_config,
    )

    # Get the dataset
    dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")

    # Initialize the trainer with the quantized base model; the LoRA config is handed over via `peft_config`
    training_args = SFTConfig(output_dir=self.tmp_dir, learning_rate=0.1, report_to="none")
    trainer = SFTTrainer(model=model, args=training_args, train_dataset=dataset, peft_config=LoraConfig())

    # Save initial parameters to check they change during training
    previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

    trainer.train()

    # Check that training completed successfully
    assert trainer.state.log_history[-1]["train_loss"] is not None
    assert trainer.state.log_history[-1]["mean_token_accuracy"] is not None

    # Check the peft params have changed and the base model params have not changed
    for n, param in previous_trainable_params.items():
        new_param = trainer.model.get_parameter(n)

        # In bitsandbytes, bias parameters are automatically cast to the input dtype during the forward pass if
        # their dtype doesn't match. This causes the module to change unexpectedly during the first forward pass
        # of the training. To handle this, we cast these specific bias parameters to float32 before comparison.
        # https://github.com/bitsandbytes-foundation/bitsandbytes/blob/45553f7392e524eacf400b132cfe01261f6477be/bitsandbytes/nn/modules.py#L518
        # We still need to investigate why the compute dtype ends up being different than for these parameters.
        if n in [
            "base_model.model.model.layers.1.self_attn.k_proj.bias",
            "base_model.model.model.layers.1.self_attn.q_proj.base_layer.bias",
            "base_model.model.model.layers.1.self_attn.v_proj.base_layer.bias",
        ]:
            param = param.float()

        # Every name either contains "lora" or it does not, so a two-way branch covers all parameters
        if "lora" in n:
            # We expect the peft parameters to be different
            assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"
        else:
            # We expect the base model parameters to be the same
            torch.testing.assert_close(param, new_param, msg=f"Parameter {n} has changed")


@require_peft
def test_prompt_tuning_peft_model(self):
    """Test that SFT works with Prompt Tuning and a pre-converted PeftModel"""
    model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", dtype="float32")
    model = get_peft_model(model, PromptEncoderConfig(task_type=TaskType.CAUSAL_LM, num_virtual_tokens=8))

    dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")

    training_args = SFTConfig(output_dir=self.tmp_dir, report_to="none")
    trainer = SFTTrainer(model=model, args=training_args, train_dataset=dataset)

    # Save initial parameters to check they change during training
    previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

    trainer.train()

    # Check that training completed successfully
    assert trainer.state.log_history[-1]["train_loss"] is not None
    assert trainer.state.log_history[-1]["mean_token_accuracy"] is not None

    # Check the peft params have changed and the base model params have not changed
    for n, param in previous_trainable_params.items():
        new_param = trainer.model.get_parameter(n)
        if "base_model" in n:
            # We expect the base model parameters to be the same
            torch.testing.assert_close(param, new_param, msg=f"Parameter {n} has changed")
        elif "prompt_encoder" in n:
            # We expect the peft parameters to be different
            assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"
        else:
            raise ValueError(f"Unexpected parameter {n} in model: {trainer.model}")


@pytest.mark.slow
@require_torch_accelerator
@require_peft
class TestSFTTrainerSlow(TrlTestCase):
    """Slow `SFTTrainer` tests that need an accelerator and PEFT; they train on 10% slices of `stanfordnlp/imdb`."""

    def setup_method(self):
        """Load the train/eval dataset slices and build the LoRA config shared by the tests."""
        self.train_dataset = load_dataset("stanfordnlp/imdb", split="train[:10%]")
        self.eval_dataset = load_dataset("stanfordnlp/imdb", split="test[:10%]")
        self.max_length = 128
        self.peft_config = LoraConfig(
            lora_alpha=16,
            lora_dropout=0.1,
            r=8,
            bias="none",
            task_type="CAUSAL_LM",
        )

    def teardown_method(self):
        """Run garbage collection and empty the accelerator cache after each test."""
        gc.collect()
        backend_empty_cache(torch_device)
        gc.collect()

    @pytest.mark.parametrize("packing", [True, False])
    @pytest.mark.parametrize(
        "model_name",
        [
            "trl-internal-testing/tiny-LlamaForCausalLM-3.2",
            "trl-internal-testing/tiny-MistralForCausalLM-0.2",
        ],
    )
    def test_sft_trainer_transformers_mp(self, model_name, packing):
        """
        Simply tests if passing a transformers model to `SFTTrainer` loads and runs the trainer as expected in mixed
        precision.
        """
        training_args = SFTConfig(
            output_dir=self.tmp_dir,
            logging_strategy="no",
            report_to="none",
            per_device_train_batch_size=2,
            max_steps=10,
            fp16=True,  # this is sufficient to enable amp
            packing=packing,
            max_length=self.max_length,
        )

        model = AutoModelForCausalLM.from_pretrained(model_name, dtype="float32")
        tokenizer = AutoTokenizer.from_pretrained(model_name)

        trainer = SFTTrainer(
            model,
            args=training_args,
            processing_class=tokenizer,
            train_dataset=self.train_dataset,
            eval_dataset=self.eval_dataset,
        )

        trainer.train()

        release_memory(model, trainer)

    @pytest.mark.parametrize("device_map", [{"": 0}, "auto"])
    @pytest.mark.parametrize(
        "gradient_checkpointing_kwargs", [None, {"use_reentrant": False}, {"use_reentrant": True}]
    )
    @pytest.mark.parametrize("packing", [True, False])
    @pytest.mark.parametrize(
        "model_name",
        [
            "trl-internal-testing/tiny-LlamaForCausalLM-3.2",
            "trl-internal-testing/tiny-MistralForCausalLM-0.2",
        ],
    )
    @require_torch_multi_accelerator
    def test_sft_trainer_transformers_mp_gc_device_map(
        self, model_name, packing, gradient_checkpointing_kwargs, device_map
    ):
        """
        Simply tests if passing a transformers model to `SFTTrainer` loads and runs the trainer as expected in mixed
        precision + different scenarios of gradient_checkpointing (single, multi-gpu, etc).
        """
        training_args = SFTConfig(
            output_dir=self.tmp_dir,
            logging_strategy="no",
            report_to="none",
            per_device_train_batch_size=2,
            max_steps=10,
            packing=packing,
            max_length=self.max_length,
            fp16=True,  # this is sufficient to enable amp
            gradient_checkpointing=True,  # default, here for clarity
            gradient_checkpointing_kwargs=gradient_checkpointing_kwargs,
        )

        model = AutoModelForCausalLM.from_pretrained(model_name, dtype="float32", device_map=device_map)
        tokenizer = AutoTokenizer.from_pretrained(model_name)

        trainer = SFTTrainer(
            model,
            args=training_args,
            processing_class=tokenizer,
            train_dataset=self.train_dataset,
            eval_dataset=self.eval_dataset,
        )

        trainer.train()

        release_memory(model, trainer)

    @pytest.mark.parametrize(
        "gradient_checkpointing_kwargs", [None, {"use_reentrant": False}, {"use_reentrant": True}]
    )
    @pytest.mark.parametrize("packing", [True, False])
    @pytest.mark.parametrize(
        "model_name",
        [
            "trl-internal-testing/tiny-LlamaForCausalLM-3.2",
            "trl-internal-testing/tiny-MistralForCausalLM-0.2",
        ],
    )
    @require_peft
    @require_bitsandbytes
    def test_sft_trainer_transformers_mp_gc_peft_qlora(self, model_name, packing, gradient_checkpointing_kwargs):
        """
        Simply tests if passing a 4-bit bitsandbytes transformers model + PEFT to `SFTTrainer` loads and runs the
        trainer as expected under the different scenarios of gradient_checkpointing.
        """
        training_args = SFTConfig(
            output_dir=self.tmp_dir,
            logging_strategy="no",
            report_to="none",
            per_device_train_batch_size=2,
            max_steps=10,
            packing=packing,
            max_length=self.max_length,
            gradient_checkpointing=True,  # default, here for clarity
            gradient_checkpointing_kwargs=gradient_checkpointing_kwargs,
        )

        quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)

        model = AutoModelForCausalLM.from_pretrained(
            model_name, dtype="float32", quantization_config=quantization_config
        )
        tokenizer = AutoTokenizer.from_pretrained(model_name)

        trainer = SFTTrainer(
            model,
            args=training_args,
            processing_class=tokenizer,
            train_dataset=self.train_dataset,
            eval_dataset=self.eval_dataset,
            peft_config=self.peft_config,
        )

        assert isinstance(trainer.model, PeftModel)

        trainer.train()

        release_memory(model, trainer)

    @pytest.mark.parametrize("packing", [True, False])
    @pytest.mark.parametrize(
        "model_name",
        [
            "trl-internal-testing/tiny-LlamaForCausalLM-3.2",
            "trl-internal-testing/tiny-MistralForCausalLM-0.2",
        ],
    )
    @require_peft
    @require_bitsandbytes
    def test_sft_trainer_with_chat_format_qlora(self, model_name, packing):
        """
        Simply tests if training on the `trl-internal-testing/dolly-chatml-sft` dataset with a transformers model +
        peft + bnb config loads and runs the trainer as expected.
        """
        train_dataset = load_dataset("trl-internal-testing/dolly-chatml-sft", split="train")

        training_args = SFTConfig(
            packing=packing,
            max_length=self.max_length,
            output_dir=self.tmp_dir,
            logging_strategy="no",
            report_to="none",
            per_device_train_batch_size=2,
            max_steps=10,
        )

        quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)

        model = AutoModelForCausalLM.from_pretrained(
            model_name, dtype="float32", quantization_config=quantization_config
        )
        tokenizer = AutoTokenizer.from_pretrained(model_name)

        trainer = SFTTrainer(
            model,
            args=training_args,
            processing_class=tokenizer,
            train_dataset=train_dataset,
            peft_config=self.peft_config,
        )

        assert isinstance(trainer.model, PeftModel)

        trainer.train()

        release_memory(model, trainer)

    @pytest.mark.parametrize("packing", [True, False])
    @pytest.mark.parametrize(
        "model_name",
        [
            "trl-internal-testing/tiny-LlamaForCausalLM-3.2",
            "trl-internal-testing/tiny-MistralForCausalLM-0.2",
        ],
    )
    @require_liger_kernel
    def test_sft_trainer_with_liger(self, model_name, packing):
        """
        Tests if passing `use_liger_kernel=True` to `SFTConfig` loads and runs the trainer as expected.
        """
        import importlib

        def cleanup_liger_patches(trainer):
            """Clean up liger_kernel patches by reloading the model's specific module"""
            try:
                # Get the specific module that was used by the trainer's model
                module_path = trainer.model.__module__
                reload_module = importlib.import_module(module_path)
                importlib.reload(reload_module)
            except Exception:
                pass  # Deliberate best-effort cleanup: continue if reload fails

        training_args = SFTConfig(
            output_dir=self.tmp_dir,
            logging_strategy="no",
            report_to="none",
            per_device_train_batch_size=2,
            max_steps=2,
            packing=packing,
            max_length=self.max_length,
            use_liger_kernel=True,
        )

        trainer = SFTTrainer(
            model_name,
            args=training_args,
            train_dataset=self.train_dataset,
            eval_dataset=self.eval_dataset,
        )

        # Ensure cleanup of liger patches after the test
        try:
            trainer.train()
            release_memory(trainer.model, trainer)
        finally:
            cleanup_liger_patches(trainer)

    @pytest.mark.parametrize("packing", [True, False])
    @pytest.mark.parametrize(
        "model_name",
        [
            "trl-internal-testing/tiny-LlamaForCausalLM-3.2",
            "trl-internal-testing/tiny-MistralForCausalLM-0.2",
        ],
    )
    @require_torch_accelerator
    def test_train_offloading(self, model_name, packing):
        """Test that activation offloading works with SFTTrainer."""
        # Initialize the trainer
        training_args = SFTConfig(
            output_dir=self.tmp_dir,
            activation_offloading=True,
            report_to="none",
            per_device_train_batch_size=2,
            max_steps=2,
            packing=packing,
            max_length=self.max_length,
        )
        trainer = SFTTrainer(
            model=model_name, args=training_args, train_dataset=self.train_dataset, eval_dataset=self.eval_dataset
        )

        # Save the initial parameters to compare them later
        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

        # Train the model
        trainer.train()

        # Check that the training loss is not None
        assert trainer.state.log_history[-1]["train_loss"] is not None

        # Check the params have changed
        for n, param in previous_trainable_params.items():
            new_param = trainer.model.get_parameter(n)
            assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"

        release_memory(trainer.model, trainer)