# Copyright 2020-2026 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import copy import textwrap from io import StringIO from unittest.mock import patch import pytest import torch import torch.nn as nn import torch.nn.functional as F import transformers from packaging.version import Version from transformers import AutoModelForCausalLM from transformers.testing_utils import torch_device from transformers.utils import is_peft_available from trl import ModelConfig from trl.trainer.utils import ( RepeatSampler, _ChunkedLogProbFunction, entropy_from_logits, flush_left, generate_model_card, get_peft_config, hash_module, nanstd, pad, patch_chunked_lm_head, print_prompt_completions_sample, selective_log_softmax, shuffle_sequence_dict, split_pixel_values_by_grid, split_tensor_dict, unsplit_pixel_values_by_grid, use_adapter, ) from .testing_utils import TrlTestCase, require_peft, require_rich, require_torch_accelerator if is_peft_available(): from peft import AutoPeftModelForCausalLM, LoraConfig @require_peft class TestUseAdapter(TrlTestCase): def test_disables_on_none(self): model = AutoPeftModelForCausalLM.from_pretrained( "trl-internal-testing/tiny-PeftModel", adapter_name="my_adapter" ) input_ids = torch.tensor([[1, 2, 3], [4, 5, 6]]) with model.disable_adapter(): expected = model(input_ids).logits with use_adapter(model, None): output = model(input_ids).logits assert torch.equal(output, expected) def test_restores_previous_adapter(self): 
model = AutoPeftModelForCausalLM.from_pretrained( "trl-internal-testing/tiny-PeftModel", adapter_name="my_adapter" ) input_ids = torch.tensor([[1, 2, 3], [4, 5, 6]]) expected = model(input_ids).logits with use_adapter(model, "my_adapter"): pass output = model(input_ids).logits assert torch.equal(output, expected) with use_adapter(model, None): pass output = model(input_ids).logits assert torch.equal(output, expected) def test_with_multiple_adapters(self): model = AutoPeftModelForCausalLM.from_pretrained( "trl-internal-testing/tiny-PeftModel", adapter_name="my_adapter_1" ) model.load_adapter("trl-internal-testing/tiny-PeftModel-2", "my_adapter_2") input_ids = torch.tensor([[1, 2, 3], [4, 5, 6]]) model.set_adapter("my_adapter_1") # should be a no-op, but let's keep it for clarity expected_1 = model(input_ids).logits model.set_adapter("my_adapter_2") expected_2 = model(input_ids).logits with use_adapter(model, "my_adapter_1"): output_1 = model(input_ids).logits with use_adapter(model, "my_adapter_2"): output_2 = model(input_ids).logits assert torch.equal(output_1, expected_1) assert torch.equal(output_2, expected_2) class TestPad(TrlTestCase): def test_pad_1_dim_left(self): x = torch.tensor([1, 2, 3]) y = torch.tensor([4, 5]) output = pad((x, y), padding_value=0, padding_side="left") expected = torch.tensor([[1, 2, 3], [0, 4, 5]]) assert torch.equal(output, expected) def test_pad_1_dim_right(self): x = torch.tensor([1, 2, 3]) y = torch.tensor([4, 5]) output = pad((x, y), padding_value=0, padding_side="right") expected = torch.tensor([[1, 2, 3], [4, 5, 0]]) assert torch.equal(output, expected) def test_pad_2_dim_left(self): x = torch.tensor([[1, 2], [3, 4]]) y = torch.tensor([[5, 6]]) output = pad((x, y), padding_value=0, padding_side="left") expected = torch.tensor( [ [[1, 2], [3, 4]], [[0, 0], [5, 6]], ] ) assert torch.equal(output, expected) def test_pad_2_dim_right(self): x = torch.tensor([[1, 2], [3, 4]]) y = torch.tensor([[5, 6]]) output = pad((x, y), 
padding_value=0, padding_side="right") expected = torch.tensor( [ [[1, 2], [3, 4]], [[5, 6], [0, 0]], ] ) assert torch.equal(output, expected) def test_pad_2_dim_right_multidim(self): x = torch.tensor([[1, 2], [3, 4]]) y = torch.tensor([[5]]) output = pad((x, y), padding_value=0, padding_side="right") expected = torch.tensor( [ [[1, 2], [3, 4]], [[5, 0], [0, 0]], ] ) assert torch.equal(output, expected) def test_pad_to_multiple_of_1(self): x = torch.tensor([1, 2, 3]) y = torch.tensor([4, 5]) # Max length is 3, pad to multiple of 4 output = pad((x, y), padding_value=0, padding_side="right", pad_to_multiple_of=4) expected = torch.tensor([[1, 2, 3, 0], [4, 5, 0, 0]]) assert torch.equal(output, expected) def test_pad_to_multiple_of_2(self): x = torch.tensor([1, 2, 3, 4, 5]) y = torch.tensor([6, 7, 8]) # Max length is 3, pad to multiple of 4 output = pad((x, y), padding_value=0, padding_side="right", pad_to_multiple_of=4) expected = torch.tensor([[1, 2, 3, 4, 5, 0, 0, 0], [6, 7, 8, 0, 0, 0, 0, 0]]) assert torch.equal(output, expected) def test_pad_to_multiple_of_side_left(self): x = torch.tensor([1, 2, 3, 4, 5]) y = torch.tensor([6, 7, 8]) # Max length is 3, pad to multiple of 4 output = pad((x, y), padding_value=0, padding_side="left", pad_to_multiple_of=4) expected = torch.tensor([[0, 0, 0, 1, 2, 3, 4, 5], [0, 0, 0, 0, 0, 6, 7, 8]]) assert torch.equal(output, expected) def test_pad_to_multiple_of_no_extra_padding(self): x = torch.tensor([1, 2, 3, 4]) y = torch.tensor([5, 6, 7, 8]) # Already multiple of 4 output = pad((x, y), padding_value=0, padding_side="left", pad_to_multiple_of=4) expected = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]]) assert torch.equal(output, expected) class TestHashModule(TrlTestCase): def test_hash_module_deterministic_across_order(self): class ModAB(torch.nn.Module): def __init__(self, a: torch.Tensor, b: torch.Tensor): super().__init__() self.a = torch.nn.Parameter(a) self.b = torch.nn.Parameter(b) class ModBA(torch.nn.Module): def 
__init__(self, a: torch.Tensor, b: torch.Tensor): super().__init__() self.b = torch.nn.Parameter(b) self.a = torch.nn.Parameter(a) a = torch.tensor([[1.0, 2.0]]) b = torch.tensor([3.0]) assert hash_module(ModAB(a, b)) == hash_module(ModBA(a, b)) def test_hash_module_changes_with_value(self): class Mod(torch.nn.Module): def __init__(self, value: float): super().__init__() self.weight = torch.nn.Parameter(torch.tensor([value, 2.0])) assert hash_module(Mod(1.0)) != hash_module(Mod(1.5)) def test_hash_module_includes_dtype(self): class Mod(torch.nn.Module): def __init__(self, dtype: torch.dtype): super().__init__() self.weight = torch.nn.Parameter(torch.tensor([1.0, 2.0], dtype=dtype)) assert hash_module(Mod(torch.float32)) != hash_module(Mod(torch.float16)) def test_hash_module_tiny_model_twice(self): model_id = "trl-internal-testing/tiny-GptOssForCausalLM" model_a = AutoModelForCausalLM.from_pretrained(model_id) model_b = AutoModelForCausalLM.from_pretrained(model_id) assert hash_module(model_a) == hash_module(model_b) def test_hash_module_tiny_model_change_layer(self): model_id = "trl-internal-testing/tiny-GptOssForCausalLM" model = AutoModelForCausalLM.from_pretrained(model_id) h1 = hash_module(model) with torch.no_grad(): model.lm_head.weight.add_(0.01) h2 = hash_module(model) assert h1 != h2 @require_peft class TestGetPEFTConfig(TrlTestCase): def test_create_peft_config_use_peft_false(self): """Test that when use_peft is False, the function returns None.""" model_args = ModelConfig(use_peft=False) peft_config = get_peft_config(model_args) assert peft_config is None def test_create_peft_config_use_peft_true(self): """Test that when use_peft is True, the function returns a LoraConfig object.""" # Provide non-default values to the model config for testing peft_kwargs = { "lora_r": 8, "lora_alpha": 16, "lora_dropout": 0.1, "lora_task_type": "SEQ_CLS", "use_rslora": True, "lora_target_modules": ["up_proj", "down_proj"], "lora_modules_to_save": ["up_proj"], } 
model_args = ModelConfig(use_peft=True, **peft_kwargs) peft_config = get_peft_config(model_args) assert isinstance(peft_config, LoraConfig) for arg, value in peft_kwargs.items(): # Test that lists of modules are converted to sets if arg == "lora_target_modules": value = set(value) # Rename the argument to match the LoraConfig attribute name if arg in ["lora_r", "lora_task_type", "lora_target_modules", "lora_modules_to_save"]: arg = arg[len("lora_") :] if arg.startswith("lora_") else arg assert getattr(peft_config, arg) == value class TestNanStd(TrlTestCase): def test_nanstd_ignores_nans(self): x = torch.tensor([1.0, 2.0, 3.0, float("nan")]) result = nanstd(x) torch.testing.assert_close(result, torch.tensor(1.0)) def test_nanstd_dim_and_keepdim(self): x = torch.tensor([[1.0, float("nan")], [3.0, 5.0]]) result = nanstd(x, dim=1, keepdim=True) assert torch.isnan(result[0, 0]) torch.testing.assert_close(result[1, 0], torch.tensor(1.4142135), rtol=1e-5, atol=1e-6) def test_nanstd_all_nan(self): x = torch.tensor([float("nan"), float("nan")]) result = nanstd(x) assert torch.isnan(result) class TestGenerateModelCard(TrlTestCase): def test_full(self): model_card = generate_model_card( base_model="username/my_base_model", model_name="my_model", hub_model_id="username/my_hub_model", dataset_name="username/my_dataset", tags=["trl", "trainer-tag"], wandb_url="https://wandb.ai/username/project_id/runs/abcd1234", trackio_url="https://huggingface.co/spaces/username/space_id", comet_url="https://www.comet.com/username/project_id/experiment_id", trainer_name="My Trainer", trainer_citation="@article{my_trainer, ...}", paper_title="My Paper", paper_id="1234.56789", ) card_text = str(model_card) assert "[username/my_base_model](https://huggingface.co/username/my_base_model)" in card_text assert "my_model" in card_text assert 'pipeline("text-generation", model="username/my_hub_model", device="cuda")' in card_text assert "datasets: username/my_dataset" in card_text assert 
"](https://wandb.ai/username/project_id/runs/abcd1234)" in card_text assert "](https://huggingface.co/spaces/username/space_id)" in card_text assert "](https://www.comet.com/username/project_id/experiment_id" in card_text assert "My Trainer" in card_text assert "```bibtex\n@article{my_trainer, ...}\n```" in card_text assert "[My Paper](https://huggingface.co/papers/1234.56789)" in card_text def test_val_none(self): model_card = generate_model_card( base_model=None, model_name="my_model", hub_model_id="username/my_hub_model", dataset_name=None, tags=[], wandb_url=None, trackio_url=None, comet_url=None, trainer_name="My Trainer", trainer_citation=None, paper_title=None, paper_id=None, ) card_text = str(model_card) assert "my_model" in card_text assert 'pipeline("text-generation", model="username/my_hub_model", device="cuda")' in card_text assert "My Trainer" in card_text class TestFlushLeft(TrlTestCase): def test_basic_case(self): mask = torch.tensor([[0, 0, 1, 1, 1], [0, 1, 1, 0, 0]]) tensor1 = torch.tensor([[0, 0, 2, 3, 4], [0, 5, 6, 0, 0]]) tensor2 = torch.tensor([[0, 0, 7, 8, 9], [0, 10, 11, 0, 0]]) new_mask, new_tensor1, new_tensor2 = flush_left(mask, tensor1, tensor2) expected_mask = torch.tensor([[1, 1, 1], [1, 1, 0]]) expected_tensor1 = torch.tensor([[2, 3, 4], [5, 6, 0]]) expected_tensor2 = torch.tensor([[7, 8, 9], [10, 11, 0]]) assert torch.equal(new_mask, expected_mask) assert torch.equal(new_tensor1, expected_tensor1) assert torch.equal(new_tensor2, expected_tensor2) def test_single_row(self): mask = torch.tensor([[0, 0, 1, 1]]) tensor1 = torch.tensor([[0, 0, 2, 3]]) new_mask, new_tensor1 = flush_left(mask, tensor1) expected_mask = torch.tensor([[1, 1]]) expected_tensor1 = torch.tensor([[2, 3]]) assert torch.equal(new_mask, expected_mask) assert torch.equal(new_tensor1, expected_tensor1) def test_no_shift_needed(self): mask = torch.tensor([[1, 1, 0, 0], [1, 0, 0, 0]]) tensor1 = torch.tensor([[5, 6, 0, 0], [7, 0, 0, 0]]) new_mask, new_tensor1 = 
flush_left(mask, tensor1) expected_mask = torch.tensor([[1, 1], [1, 0]]) expected_tensor1 = torch.tensor([[5, 6], [7, 0]]) assert torch.equal(new_mask, expected_mask) assert torch.equal(new_tensor1, expected_tensor1) def test_no_tensors(self): mask = torch.tensor([[0, 0, 1, 1, 1], [0, 1, 1, 0, 0]]) new_mask = flush_left(mask) expected_mask = torch.tensor([[1, 1, 1], [1, 1, 0]]) assert torch.equal(new_mask, expected_mask) class TestRepeatRandomSampler(TrlTestCase): def test_sampler(self): dataset = ["a", "b", "c", "d", "e", "f", "g"] sampler = RepeatSampler(dataset, mini_repeat_count=2) # Should output something like [4, 4, 3, 3, 0, 0, 1, 1, 2, 2, 6, 6, 5, 5] sampled = list(sampler) # Check that the length is doubled assert len(sampled) == 2 * len(dataset) # Check that all indexes are present assert set(sampled) == set(range(len(dataset))) # Check that each element is repeated twice assert all(sampled[i] == sampled[i + 1] for i in range(0, len(sampled), 2)) def test_sampler_no_shuffle(self): dataset = ["a", "b", "c", "d", "e", "f", "g"] sampler = RepeatSampler(dataset, mini_repeat_count=2, shuffle=False) sampled = list(sampler) expected = [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6] assert sampled == expected def test_sampler_no_repeat(self): dataset = ["a", "b", "c", "d", "e", "f", "g"] sampler = RepeatSampler(dataset, mini_repeat_count=1) # Should output something like [4, 3, 0, 1, 2, 6, 5] sampled = list(sampler) # Check that the length is the same assert len(sampled) == len(dataset) # Check that all indexes are present assert set(sampled) == set(range(len(dataset))) def test_sampler_with_batch_size(self): dataset = ["a", "b", "c", "d", "e", "f", "g", "h"] sampler = RepeatSampler(dataset, mini_repeat_count=1, batch_size=2, repeat_count=2) # Should output something like [4, 3, 4, 3, 0, 1, 0, 1, 2, 6, 2, 6, 5, 7, 5, 7] sampled = list(sampler) # Check that the length is doubled assert len(sampled) == 2 * len(dataset) # Check that all indexes are present assert 
set(sampled) == set(range(len(dataset))) # Check that each element is repeated as expected assert all(sampled[i : i + 1] == sampled[i + 2 : i + 3] for i in range(0, len(sampled), 4)) def test_sampler_with_batch_size_and_drop(self): dataset = ["a", "b", "c", "d", "e", "f", "g"] sampler = RepeatSampler(dataset, mini_repeat_count=1, batch_size=2, repeat_count=2) # Should output something like [4, 3, 4, 3, 0, 1, 0, 1, 2, 6, 2, 6] sampled = list(sampler) # Check that the length is doubled assert len(sampled) == 2 * ( len(dataset) - 1 ) # one element is dropped, because it's not enough to form a batch assert len(sampler) == len(sampled) # the length should be the same as the sampled length # Check that the sampled indexes are a subset of the dataset indexes assert set(sampled).issubset(set(range(len(dataset)))) # Check that each element is repeated as expected assert all(sampled[i : i + 1] == sampled[i + 2 : i + 3] for i in range(0, len(sampled), 4)) def test_sampler_with_mini_repeat_count_and_batch_size_1(self): dataset = ["a", "b", "c", "d", "e", "f", "g"] sampler = RepeatSampler(dataset, mini_repeat_count=2, batch_size=3, repeat_count=2) # Should output something like [4, 4, 3, 3, 0, 0, 4, 4, 3, 3, 0, 0, # 1, 1, 2, 2, 6, 6, 1, 1, 2, 2, 6, 6] sampled = list(sampler) # Check that the length is quadrupled assert len(sampled) == 4 * (len(dataset) - 1) # 1 element is dropped, because it's not enough to form a batch assert len(sampler) == len(sampled) # the length should be the same as the sampled length # Check that the sampled indexes are a subset of the dataset indexes assert set(sampled).issubset(set(range(len(dataset)))) # Check that each element is repeated as expected assert all(sampled[i] == sampled[i + 1] for i in range(0, len(sampled), 2)) # Check that the batch is repeated as expected assert sampled[0:6] == sampled[6:12] assert sampled[12:18] == sampled[18:24] def test_sampler_with_mini_repeat_count_and_batch_size_2(self): dataset = ["a", "b", "c", "d", "e", "f", 
"g"] sampler = RepeatSampler(dataset, mini_repeat_count=3, batch_size=2, repeat_count=2) # Should output something like [4, 4, 4, 3, 3, 3, 4, 4, 4, 3, 3, 3, # 0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, # 2, 2, 2, 6, 6, 6, 2, 2, 2, 6, 6, 6] sampled = list(sampler) # Check that the length is sextupled assert len(sampled) == 6 * (len(dataset) - 1) # 1 element is dropped, because it's not enough to form a batch assert len(sampler) == len(sampled) # the length should be the same as the sampled length # Check that the sampled indexes are a subset of the dataset indexes assert set(sampled).issubset(set(range(len(dataset)))) # Check that each element is repeated as expected assert all(sampled[i] == sampled[i + 1] == sampled[i + 2] for i in range(0, len(sampled), 3)) # Check that the batch is repeated as expected assert sampled[0:6] == sampled[6:12] assert sampled[12:18] == sampled[18:24] assert sampled[24:30] == sampled[30:36] def test_sampler_with_mini_repeat_count_and_batch_size_3(self): dataset = ["a", "b", "c", "d", "e", "f", "g"] sampler = RepeatSampler(dataset, mini_repeat_count=2, batch_size=2, repeat_count=3) # Should output something like [4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3, # 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, # 2, 2, 6, 6, 2, 2, 6, 6, 2, 2, 6, 6] sampled = list(sampler) # Check that the length is sextupled assert len(sampled) == 6 * (len(dataset) - 1) # 1 element is dropped, because it's not enough to form a batch # Check that the sampled indexes are a subset of the dataset indexes assert set(sampled).issubset(set(range(len(dataset)))) # Check that each element is repeated as expected assert all(sampled[i] == sampled[i + 1] for i in range(0, len(sampled), 2)) # Check that the batch is repeated as expected assert sampled[0:4] == sampled[4:8] == sampled[8:12] assert sampled[12:16] == sampled[16:20] == sampled[20:24] assert sampled[24:28] == sampled[28:32] == sampled[32:36] class TestEntropyFromLogits(TrlTestCase): @pytest.mark.parametrize("shape", [(768,), (32, 768), 
(8, 16, 768), (2, 4, 8, 768)]) @pytest.mark.parametrize("chunk_size", [1, 16]) @pytest.mark.parametrize("dtype", [torch.float64, torch.float32, torch.float16, torch.bfloat16]) def test_entropy_from_logits_2_dims(self, dtype, chunk_size, shape): logits = torch.randn(*shape, dtype=dtype) if dtype in (torch.float64, torch.float32): p = logits.softmax(-1) entropy = -torch.sum(p * p.log(), dim=-1) else: logps = logits.log_softmax(dim=-1) entropy = -(torch.exp(logps) * logps).sum(-1) predicted_entropy = entropy_from_logits(logits, chunk_size=chunk_size) torch.testing.assert_close(predicted_entropy, entropy, rtol=1e-5, atol=1e-5) @require_rich class TestPrintPromptCompletionsSample(TrlTestCase): @patch("sys.stdout", new_callable=StringIO) def test_print_output(self, mock_stdout): prompts = ["The sky is", "The sun is"] completions = [" blue.", " in the sky."] rewards = {"Correctness": [0.123, 0.456], "Format": [0.789, 0.101]} advantages = [0.987, 0.654] step = 42 print_prompt_completions_sample(prompts, completions, rewards, advantages, step) output = mock_stdout.getvalue() # docstyle-ignore expected_output = textwrap.dedent("""\ ╭──────────────────────────── Step 42 ─────────────────────────────╮ │ ┏━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━┳━━━━━━━━━━━┓ │ │ ┃ Prompt ┃ Completion ┃ Correctness ┃ Format ┃ Advantage ┃ │ │ ┡━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━╇━━━━━━━━━━━┩ │ │ │ The sky is │ blue. │ 0.12 │ 0.79 │ 0.99 │ │ │ ├────────────┼──────────────┼─────────────┼────────┼───────────┤ │ │ │ The sun is │ in the sky. 
│ 0.46 │ 0.10 │ 0.65 │ │ │ └────────────┴──────────────┴─────────────┴────────┴───────────┘ │ ╰──────────────────────────────────────────────────────────────────╯ """) assert output == expected_output @patch("sys.stdout", new_callable=StringIO) def test_extra_columns(self, mock_stdout): prompts = ["The sky is", "The sun is"] completions = [" blue.", " in the sky."] rewards = {"Correctness": [0.123, 0.456], "Format": [0.789, 0.101]} advantages = [0.987, 0.654] extra = {"source": ["dataset_A", "dataset_B"]} step = 42 print_prompt_completions_sample(prompts, completions, rewards, advantages, step, extra=extra) output = mock_stdout.getvalue() # docstyle-ignore expected_output = textwrap.dedent("""\ ╭────────────────────────────────── Step 42 ───────────────────────────────────╮ │ ┏━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━┓ │ │ ┃ Prompt ┃ Completion ┃ Correctness ┃ Format ┃ Advantage ┃ source ┃ │ │ ┡━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━┩ │ │ │ The sky is │ blue. │ 0.12 │ 0.79 │ 0.99 │ dataset_A │ │ │ ├────────────┼──────────────┼─────────────┼────────┼───────────┼───────────┤ │ │ │ The sun is │ in the sky. 
│ 0.46 │ 0.10 │ 0.65 │ dataset_B │ │ │ └────────────┴──────────────┴─────────────┴────────┴───────────┴───────────┘ │ ╰──────────────────────────────────────────────────────────────────────────────╯ """) assert output == expected_output @patch("sys.stdout", new_callable=StringIO) def test_num_samples(self, mock_stdout): prompts = ["A", "B"] completions = ["1", "2"] rewards = {"Score": [0.1, 0.2]} advantages = [0.3, 0.4] step = 10 print_prompt_completions_sample(prompts, completions, rewards, advantages, step, num_samples=1) output = mock_stdout.getvalue() # docstyle-ignore possible_outputs = [ textwrap.dedent("""\ ╭────────────────── Step 10 ──────────────────╮ │ ┏━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━┳━━━━━━━━━━━┓ │ │ ┃ Prompt ┃ Completion ┃ Score ┃ Advantage ┃ │ │ ┡━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━╇━━━━━━━━━━━┩ │ │ │ A │ 1 │ 0.10 │ 0.30 │ │ │ └────────┴────────────┴───────┴───────────┘ │ ╰─────────────────────────────────────────────╯ """), # docstyle-ignore textwrap.dedent("""\ ╭────────────────── Step 10 ──────────────────╮ │ ┏━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━┳━━━━━━━━━━━┓ │ │ ┃ Prompt ┃ Completion ┃ Score ┃ Advantage ┃ │ │ ┡━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━╇━━━━━━━━━━━┩ │ │ │ B │ 2 │ 0.20 │ 0.40 │ │ │ └────────┴────────────┴───────┴───────────┘ │ ╰─────────────────────────────────────────────╯ """), ] assert output in possible_outputs @patch("sys.stdout", new_callable=StringIO) def test_print_messages(self, mock_stdout): prompts = [ [ {"role": "system", "content": "You are an helpful assistant."}, {"role": "user", "content": "What color is the sky?"}, ], [ {"role": "system", "content": "You are an helpful assistant."}, {"role": "user", "content": "Where is the sun?"}, ], ] completions = [ [{"role": "assistant", "content": "It is blue."}], [{"role": "assistant", "content": "In the sky."}], ] rewards = {"Correctness": [0.123, 0.456], "Format": [0.789, 0.101]} advantages = [0.987, 0.654] step = 42 print_prompt_completions_sample(prompts, completions, rewards, advantages, step) output 
= mock_stdout.getvalue() # docstyle-ignore expected_output = textwrap.dedent("""\ ╭────────────────────────────────── Step 42 ───────────────────────────────────╮ │ ┏━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━┳━━━━━━━━━━━┓ │ │ ┃ Prompt ┃ Completion ┃ Correctness ┃ Format ┃ Advantage ┃ │ │ ┡━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━╇━━━━━━━━━━━┩ │ │ │ SYSTEM │ ASSISTANT │ 0.12 │ 0.79 │ 0.99 │ │ │ │ You are an helpful │ It is blue. │ │ │ │ │ │ │ assistant. │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ USER │ │ │ │ │ │ │ │ What color is the sky? │ │ │ │ │ │ │ ├─────────────────────────┼─────────────┼─────────────┼────────┼───────────┤ │ │ │ SYSTEM │ ASSISTANT │ 0.46 │ 0.10 │ 0.65 │ │ │ │ You are an helpful │ In the sky. │ │ │ │ │ │ │ assistant. │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ USER │ │ │ │ │ │ │ │ Where is the sun? │ │ │ │ │ │ │ └─────────────────────────┴─────────────┴─────────────┴────────┴───────────┘ │ ╰──────────────────────────────────────────────────────────────────────────────╯ """) assert output == expected_output @patch("sys.stdout", new_callable=StringIO) def test_print_messages_with_tools(self, mock_stdout): prompts = [ [{"role": "user", "content": "What is the temperature in Paris?"}], [{"role": "user", "content": "What is the weather in London?"}], ] completions = [ [{"role": "tool", "name": "get_temperature", "args": {"location": "Paris"}}], [{"role": "tool", "name": "get_weather", "args": {"location": "London"}}], ] rewards = {"Correctness": [0.123, 0.456], "Format": [0.789, 0.101]} advantages = [0.987, 0.654] step = 42 print_prompt_completions_sample(prompts, completions, rewards, advantages, step) output = mock_stdout.getvalue() # docstyle-ignore expected_output = textwrap.dedent("""\ ╭────────────────────────────────── Step 42 ───────────────────────────────────╮ │ ┏━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━┳━━━━━━━━━━━┓ │ │ ┃ Prompt ┃ Completion ┃ Correctness ┃ Format ┃ Advantage ┃ │ │ 
┡━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━╇━━━━━━━━━━━┩ │ │ │ USER │ TOOL │ 0.12 │ 0.79 │ 0.99 │ │ │ │ What is the │ get_temperature(… │ │ │ │ │ │ │ temperature in │ 'Paris'}) │ │ │ │ │ │ │ Paris? │ │ │ │ │ │ │ ├───────────────────┼───────────────────┼─────────────┼────────┼───────────┤ │ │ │ USER │ TOOL │ 0.46 │ 0.10 │ 0.65 │ │ │ │ What is the │ get_weather({'lo… │ │ │ │ │ │ │ weather in │ 'London'}) │ │ │ │ │ │ │ London? │ │ │ │ │ │ │ └───────────────────┴───────────────────┴─────────────┴────────┴───────────┘ │ ╰──────────────────────────────────────────────────────────────────────────────╯ """) assert output == expected_output @patch("sys.stdout", new_callable=StringIO) def test_print_messages_with_reasoning_content(self, mock_stdout): prompts = [[{"role": "user", "content": "What color is the sky?"}]] completions = [[{"role": "assistant", "reasoning_content": "I think it is blue.", "content": "It is blue."}]] rewards = {"Score": [0.5]} advantages = [0.9] step = 1 print_prompt_completions_sample(prompts, completions, rewards, advantages, step) output = mock_stdout.getvalue() # docstyle-ignore expected_output = textwrap.dedent("""\ ╭─────────────────────────────── Step 1 ───────────────────────────────╮ │ ┏━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━┳━━━━━━━━━━━┓ │ │ ┃ Prompt ┃ Completion ┃ Score ┃ Advantage ┃ │ │ ┡━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━╇━━━━━━━━━━━┩ │ │ │ USER │ ASSISTANT │ 0.50 │ 0.90 │ │ │ │ What color is the sky? │ I think it is blue. │ │ │ │ │ │ │ It is blue. 
│ │ │ │ │ └────────────────────────┴─────────────────────┴───────┴───────────┘ │ ╰──────────────────────────────────────────────────────────────────────╯ """) assert output == expected_output @patch("sys.stdout", new_callable=StringIO) def test_print_messages_with_thinking(self, mock_stdout): prompts = [[{"role": "user", "content": "What color is the sky?"}]] completions = [[{"role": "assistant", "thinking": "I think it is blue.", "content": "It is blue."}]] rewards = {"Score": [0.5]} advantages = [0.9] step = 1 print_prompt_completions_sample(prompts, completions, rewards, advantages, step) output = mock_stdout.getvalue() # docstyle-ignore expected_output = textwrap.dedent("""\ ╭─────────────────────────────── Step 1 ───────────────────────────────╮ │ ┏━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━┳━━━━━━━━━━━┓ │ │ ┃ Prompt ┃ Completion ┃ Score ┃ Advantage ┃ │ │ ┡━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━╇━━━━━━━━━━━┩ │ │ │ USER │ ASSISTANT │ 0.50 │ 0.90 │ │ │ │ What color is the sky? │ I think it is blue. │ │ │ │ │ │ │ It is blue. 
│ │ │ │ │ └────────────────────────┴─────────────────────┴───────┴───────────┘ │ ╰──────────────────────────────────────────────────────────────────────╯ """) assert output == expected_output class TestSelectiveLogSoftmax(TrlTestCase): @pytest.mark.parametrize("dtype", [torch.float64, torch.float32, torch.float16, torch.bfloat16]) def test_selective_log_softmax(self, dtype): """Test selective_log_softmax with logits of different dtypes""" vocab_size = 1024 batch_size = 4 seq_len = 32 input_ids = torch.randint(low=0, high=vocab_size, size=(batch_size, seq_len)) logits = torch.randn(batch_size, seq_len, vocab_size, dtype=dtype) expected_output = torch.gather(logits.log_softmax(-1), dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1) actual_output = selective_log_softmax(logits, input_ids) if dtype in [torch.float16, torch.bfloat16]: # half-precision dtypes fall back to an exact method assert torch.equal(actual_output, expected_output) else: torch.testing.assert_close(actual_output, expected_output, rtol=1e-5, atol=1e-5) @pytest.mark.parametrize("dtype", [torch.float64, torch.float32, torch.float16, torch.bfloat16]) @pytest.mark.parametrize("k", [1, 8]) def test_selective_log_softmax_multi_index(self, dtype, k): """Test selective_log_softmax with logits of different dtypes and index widths""" vocab_size = 1024 batch_size = 4 seq_len = 32 index = torch.randint(low=0, high=vocab_size, size=(batch_size, seq_len, k)) logits = torch.randn(batch_size, seq_len, vocab_size, dtype=dtype) expected_output = torch.gather(logits.log_softmax(-1), dim=-1, index=index) actual_output = selective_log_softmax(logits, index) assert actual_output.shape == (batch_size, seq_len, k) if dtype in [torch.float16, torch.bfloat16]: # half-precision dtypes fall back to an exact method assert torch.equal(actual_output, expected_output) else: torch.testing.assert_close(actual_output, expected_output, rtol=1e-5, atol=1e-5) class TestShuffleSequenceDict(TrlTestCase): def 
test_shuffle_preserves_shape(self):
        x = torch.arange(6).reshape(3, 2)
        y = torch.arange(3).reshape(3, 1)
        tensor_dict = {"x": x.clone(), "y": y.clone()}

        shuffled = shuffle_sequence_dict(tensor_dict)

        assert shuffled["x"].shape == x.shape
        assert shuffled["y"].shape == y.shape

    def test_shuffle_consistent_across_tensors(self):
        """Every shuffled `x` row must still sit at the same index as its original `y` value."""
        # Use known patterns to check alignment
        x = torch.tensor([[10, 11], [20, 21], [30, 31]])
        y = torch.tensor([[1], [2], [3]])
        tensor_dict = {"x": x.clone(), "y": y.clone()}

        shuffled = shuffle_sequence_dict(tensor_dict)

        # Identify each shuffled x row by its content and check the y value found at the same index
        for i in range(3):
            x_row = shuffled["x"][i]
            y_val = shuffled["y"][i].item()

            if torch.equal(x_row, torch.tensor([10, 11])):
                assert y_val == 1
            elif torch.equal(x_row, torch.tensor([20, 21])):
                assert y_val == 2
            elif torch.equal(x_row, torch.tensor([30, 31])):
                assert y_val == 3
            else:
                pytest.fail("Unexpected x row in shuffled output.")

    def test_none_tensor_remains_none(self):
        """A `None` entry must come back as `None`; the tensor entry keeps its shape."""
        x = torch.arange(6).reshape(3, 2)
        tensor_dict = {"x": x.clone(), "y": None}

        shuffled = shuffle_sequence_dict(tensor_dict)

        assert shuffled["y"] is None
        assert shuffled["x"].shape == x.shape

    def test_shuffle_with_list(self):
        """A plain Python list value must be reordered with the same permutation as the tensor value."""
        x = torch.tensor([[10, 11], [20, 21], [30, 31]])
        y = ["a", "b", "c"]

        sequence_dict = {"x": x.clone(), "y": y}

        shuffled = shuffle_sequence_dict(sequence_dict)

        # Check that the list y is shuffled in the same order as x
        for i in range(3):
            x_row = shuffled["x"][i]
            y_val = shuffled["y"][i]

            if torch.equal(x_row, torch.tensor([10, 11])):
                assert y_val == "a"
            elif torch.equal(x_row, torch.tensor([20, 21])):
                assert y_val == "b"
            elif torch.equal(x_row, torch.tensor([30, 31])):
                assert y_val == "c"
            else:
                pytest.fail("Unexpected x row in shuffled output.")


class TestSplitTensorDict(TrlTestCase):
    """Tests for `split_tensor_dict`; expected chunks are built with `torch.chunk` along dim 0."""

    def test_split_equal_chunks(self):
        """Six rows split in 3: every tensor entry must match the corresponding `torch.chunk` piece."""
        x = torch.arange(12).reshape(6, 2)
        y = torch.arange(6).reshape(6, 1)
        tensor_dict = {"x": x, "y": y}

        result = split_tensor_dict(tensor_dict, 3)

        expected_x_chunks = torch.chunk(x, 3, dim=0)
        expected_y_chunks = torch.chunk(y, 3, dim=0)
        assert len(result) == 3
        for i in range(3):
            assert torch.equal(result[i]["x"], expected_x_chunks[i])
            assert torch.equal(result[i]["y"], expected_y_chunks[i])

    def test_with_none_tensor(self):
        """A `None` entry must be present as `None` in every resulting chunk."""
        x = torch.arange(12).reshape(6, 2)
        tensor_dict = {"x": x, "y": None}

        result = split_tensor_dict(tensor_dict, 2)

        expected_x_chunks = torch.chunk(x, 2, dim=0)
        assert len(result) == 2
        for i in range(2):
            assert torch.equal(result[i]["x"], expected_x_chunks[i])
            assert result[i]["y"] is None

    def test_with_scalar(self):
        """A 0-dim tensor entry must appear with the same value in every resulting chunk."""
        x = torch.arange(12).reshape(6, 2)
        tensor_dict = {"x": x, "y": torch.tensor(1)}

        result = split_tensor_dict(tensor_dict, 2)

        expected_x_chunks = torch.chunk(x, 2, dim=0)
        assert len(result) == 2
        for i in range(2):
            assert torch.equal(result[i]["x"], expected_x_chunks[i])
            assert torch.equal(result[i]["y"], torch.tensor(1))


class TestSplitPixelValuesByGrid(TrlTestCase):
    """Tests for `split_pixel_values_by_grid`.

    The expectations below treat each `image_grid_thw` row as accounting for (product of its entries) rows of
    `pixel_values` (e.g. `[1, 2, 2]` -> 4 rows), and `num_images` as the number of images per returned list item.
    """

    def test_split_correctly_0(self):
        """Two samples with one 4-row image each: `pixel_values` is split 4 / 4."""
        batch = {
            "image_grid_thw": torch.tensor([[1, 2, 2], [1, 2, 2]]),
            "num_images": [1, 1],
            "pixel_values": torch.arange(8 * 3).reshape(8, 3),  # Shape: [8, 3]
        }
        result = split_pixel_values_by_grid(batch)
        assert isinstance(result["pixel_values"], list)
        assert len(result["pixel_values"]) == 2
        assert torch.equal(result["pixel_values"][0], batch["pixel_values"][:4])
        assert torch.equal(result["pixel_values"][1], batch["pixel_values"][4:])
        assert isinstance(result["image_grid_thw"], list)
        assert len(result["image_grid_thw"]) == 2
        assert torch.equal(result["image_grid_thw"][0], torch.tensor([[1, 2, 2]]))
        assert torch.equal(result["image_grid_thw"][1], torch.tensor([[1, 2, 2]]))

    def test_split_correctly_1(self):
        """Two samples with differently sized images (4 and 8 rows): `pixel_values` is split 4 / 8."""
        batch = {
            "image_grid_thw": torch.tensor([[1, 2, 2], [1, 2, 4]]),
            "num_images": [1, 1],
            "pixel_values": torch.arange(12 * 3).reshape(12, 3),  # Shape: [12, 3]
        }
        result = split_pixel_values_by_grid(batch)
        assert isinstance(result["pixel_values"], list)
        assert len(result["pixel_values"]) == 2
        assert torch.equal(result["pixel_values"][0], batch["pixel_values"][:4])
        assert torch.equal(result["pixel_values"][1], batch["pixel_values"][4:12])
        assert isinstance(result["image_grid_thw"], list)
        assert len(result["image_grid_thw"]) == 2
        assert torch.equal(result["image_grid_thw"][0], torch.tensor([[1, 2, 2]]))
        assert torch.equal(result["image_grid_thw"][1], torch.tensor([[1, 2, 4]]))

    def test_missing_keys(self):
        """With only `pixel_values` present, the result must compare equal to the input batch."""
        batch = {"pixel_values": torch.tensor([1.0])}
        result = split_pixel_values_by_grid(batch)
        assert result == batch

    def test_mismatched_length(self):
        """A `pixel_values` row count that disagrees with the grids must raise `ValueError`."""
        batch = {
            "image_grid_thw": torch.tensor([[1, 1, 2], [1, 2, 1]]),  # Total = 4 (2 + 2)
            "num_images": [1, 1],
            "pixel_values": torch.randn(3, 5),  # Only 3 rows
        }
        with pytest.raises(ValueError):
            split_pixel_values_by_grid(batch)

    def test_multi_images(self):
        """`num_images=[1, 2]`: the second item receives the rows and grid rows of two images."""
        batch = {
            "image_grid_thw": torch.tensor([[1, 1, 2], [1, 2, 2], [1, 2, 1]]),  # Total = 8
            "num_images": [1, 2],
            "pixel_values": torch.arange(8 * 3).reshape(8, 3),  # Shape: [8, 3]
        }
        result = split_pixel_values_by_grid(batch)
        assert isinstance(result["pixel_values"], list)
        assert len(result["pixel_values"]) == 2
        assert torch.equal(result["pixel_values"][0], batch["pixel_values"][:2])
        assert torch.equal(result["pixel_values"][1], batch["pixel_values"][2:])
        assert isinstance(result["image_grid_thw"], list)
        assert len(result["image_grid_thw"]) == 2
        assert torch.equal(result["image_grid_thw"][0], torch.tensor([[1, 1, 2]]))
        assert torch.equal(result["image_grid_thw"][1], torch.tensor([[1, 2, 2], [1, 2, 1]]))

    def test_split_by_image_position_ids(self):
        """Without `image_grid_thw`, `pixel_values` and `image_position_ids` are split 1 / 2 rows."""
        # Gemma-style: no image_grid_thw, split by num_images using image_position_ids
        batch = {
            "num_images": [1, 2],
            "pixel_values": torch.arange(3 * 4).reshape(3, 4),
            "image_position_ids": torch.tensor([[0, 1], [2, 3], [4, 5]]),
        }
        result = split_pixel_values_by_grid(batch)
        assert isinstance(result["pixel_values"], list)
        assert len(result["pixel_values"]) == 2
        assert torch.equal(result["pixel_values"][0], batch["pixel_values"][:1])
        assert torch.equal(result["pixel_values"][1], batch["pixel_values"][1:])
        assert isinstance(result["image_position_ids"], list)
        assert len(result["image_position_ids"]) == 2
        assert torch.equal(result["image_position_ids"][0], batch["image_position_ids"][:1])
        assert torch.equal(result["image_position_ids"][1], batch["image_position_ids"][1:])


class TestUnsplitPixelValuesByGrid(TrlTestCase):
    """Tests for `unsplit_pixel_values_by_grid`; expected tensors are built with `torch.cat` along dim 0."""

    def test_unsplit_correctly(self):
        """List-valued `pixel_values` / `image_grid_thw` become single tensors; unrelated keys survive."""
        pixel_values = [torch.randn(4, 5), torch.randn(2, 5)]
        pixel_values_merged = torch.cat(pixel_values, dim=0)
        image_grid_thw = [torch.tensor([[1, 2, 2]]), torch.tensor([[1, 2, 1]])]
        image_grid_thw_merged = torch.cat(image_grid_thw, dim=0)
        batch = {"pixel_values": pixel_values, "image_grid_thw": image_grid_thw, "other_key": torch.tensor([1])}
        result = unsplit_pixel_values_by_grid(batch)
        assert isinstance(result["pixel_values"], torch.Tensor)
        torch.testing.assert_close(result["pixel_values"], pixel_values_merged)
        assert isinstance(result["image_grid_thw"], torch.Tensor)
        assert torch.equal(result["image_grid_thw"], image_grid_thw_merged)
        assert "other_key" in result

    def test_unsplit_image_position_ids(self):
        """List-valued `image_position_ids` must be merged into a single tensor as well."""
        image_position_ids = [torch.tensor([[0, 1]]), torch.tensor([[2, 3], [4, 5]])]
        image_position_ids_merged = torch.cat(image_position_ids, dim=0)
        pixel_values = [torch.randn(1, 4), torch.randn(2, 4)]
        batch = {"pixel_values": pixel_values, "image_position_ids": image_position_ids}
        result = unsplit_pixel_values_by_grid(batch)
        assert isinstance(result["image_position_ids"], torch.Tensor)
        assert torch.equal(result["image_position_ids"], image_position_ids_merged)

    def test_no_op_if_not_list(self):
        """A `pixel_values` entry that is already a tensor must come back with equal contents."""
        original = torch.randn(5, 3)
        batch = {"pixel_values": original}
        result = unsplit_pixel_values_by_grid(batch)
        assert torch.equal(result["pixel_values"], original)


class TestChunkedLogProbFunction:
    """Compares `_ChunkedLogProbFunction.apply` against a dense float32 reference, forward and backward."""

    # Sizes used by every test: N rows of hidden states, hidden width H, vocabulary size V.
    N, H, V = 64, 32, 128
    # Passed as the last argument of `_ChunkedLogProbFunction.apply`.
    CHUNK_SIZE = 32

    def _reference_logprobs_and_entropy(self, hidden, weight, labels, temperature):
        """Dense reference: per-row log-probability of `labels` and entropy of the softmax distribution.

        Logits are `hidden @ weight.T`, cast to float32 and divided by `temperature`.
        Returns the tuple `(logprobs, entropy)`.
        """
        logits = (hidden @ weight.t()).to(torch.float32) / temperature  # [N, V]
        log_p = F.log_softmax(logits, dim=-1)
        logprobs = log_p.gather(-1, labels.unsqueeze(-1)).squeeze(-1)
        p = torch.softmax(logits, dim=-1)
        entropy = -(p * log_p).sum(dim=-1)
        return logprobs, entropy

    @pytest.mark.parametrize("temperature", [1.0, 0.7])
    def test_forward(self, temperature):
        """Chunked log-probs and entropy must match the dense reference within 1e-5."""
        torch.manual_seed(42)
        hidden = torch.randn(self.N, self.H)
        weight = torch.randn(self.V, self.H)
        labels = torch.randint(0, self.V, (self.N,))

        logprobs_chunked, entropy_chunked = _ChunkedLogProbFunction.apply(
            hidden, weight, labels, temperature, self.CHUNK_SIZE
        )
        logprobs_ref, entropy_ref = self._reference_logprobs_and_entropy(hidden, weight, labels, temperature)

        torch.testing.assert_close(logprobs_chunked, logprobs_ref, atol=1e-5, rtol=1e-5)
        torch.testing.assert_close(entropy_chunked, entropy_ref, atol=1e-5, rtol=1e-5)

    @pytest.mark.parametrize("temperature", [1.0, 0.7])
    def test_backward(self, temperature):
        """Gradients of `sum(logprobs)` w.r.t. `hidden` and `weight` must match the dense reference."""
        torch.manual_seed(42)
        hidden = torch.randn(self.N, self.H, requires_grad=True)
        weight = torch.randn(self.V, self.H, requires_grad=True)
        labels = torch.randint(0, self.V, (self.N,))

        # Chunked backward
        logprobs_chunked, _ = _ChunkedLogProbFunction.apply(hidden, weight, labels, temperature, self.CHUNK_SIZE)
        logprobs_chunked.sum().backward()
        grad_hidden_chunked = hidden.grad.clone()
        grad_weight_chunked = weight.grad.clone()
        # Clear the accumulated grads so the reference pass below starts from scratch on the same leaves
        hidden.grad = None
        weight.grad = None

        # Reference backward
        logprobs_ref, _ = self._reference_logprobs_and_entropy(hidden, weight, labels, temperature)
        logprobs_ref.sum().backward()

        torch.testing.assert_close(grad_hidden_chunked, hidden.grad, atol=1e-5, rtol=1e-5)
        torch.testing.assert_close(grad_weight_chunked, weight.grad, atol=1e-5, rtol=1e-5)

    @pytest.mark.parametrize("temperature", [1.0, 0.7])
    def test_backward_bfloat16(self, temperature):
        """Same check as `test_backward` with bfloat16 inputs and a looser 1e-2 tolerance."""
        torch.manual_seed(42)
        hidden = torch.randn(self.N, self.H, dtype=torch.bfloat16, requires_grad=True)
        weight = torch.randn(self.V, self.H, dtype=torch.bfloat16, requires_grad=True)
        labels = torch.randint(0, self.V, (self.N,))

        # Chunked backward
        logprobs_chunked, _ = _ChunkedLogProbFunction.apply(hidden, weight, labels, temperature, self.CHUNK_SIZE)
        logprobs_chunked.sum().backward()
        grad_hidden_chunked = hidden.grad.clone()
        grad_weight_chunked = weight.grad.clone()
        # Clear the accumulated grads so the reference pass below starts from scratch on the same leaves
        hidden.grad = None
        weight.grad = None

        # Reference backward
        logprobs_ref, _ = self._reference_logprobs_and_entropy(hidden, weight, labels, temperature)
        logprobs_ref.sum().backward()

        torch.testing.assert_close(grad_hidden_chunked, hidden.grad, atol=1e-2, rtol=1e-2)
        torch.testing.assert_close(grad_weight_chunked, weight.grad, atol=1e-2, rtol=1e-2)


class _FakeTransformerModel(nn.Module):
    """Minimal stand-in for a transformer body: returns random hidden states of the right shape."""

    def __init__(self, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        # Cached hidden states; regenerated when unset or when the (batch, seq) shape changes
        self._hidden = None

    def forward(self, input_ids, attention_mask=None, use_cache=False, **kwargs):
        """Return an object whose `last_hidden_state` is a cached random `(b, s, hidden_size)` tensor.

        Only the shape of `input_ids` is used; the remaining arguments are accepted and ignored.
        """
        b, s = input_ids.shape
        if self._hidden is None or self._hidden.shape[:2] != (b, s):
            # Fixed seed: regenerating after a reset (`_hidden = None`) yields the same values.
            # NOTE: this reseeds the global torch RNG as a side effect.
            torch.manual_seed(123)
            self._hidden = torch.randn(b, s, self.hidden_size, requires_grad=True)
        # Throwaway class instance exposing `last_hidden_state` as an attribute
        return type("Out", (), {"last_hidden_state": self._hidden})()


class _FakeCausalLM(nn.Module):
    """Minimal CausalLM with .model and .lm_head, enough for patch_chunked_lm_head."""

    def __init__(self, hidden_size, vocab_size):
        super().__init__()
        # Empty attribute container standing in for a model config
        self.config = type("Config", (), {})()
        self.model = _FakeTransformerModel(hidden_size)
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
        """Placeholder that always raises; the tests call the model only after `patch_chunked_lm_head`."""
        raise NotImplementedError("should be monkey-patched")


# Checkpoints used to parametrize TestPatchChunkedLMHead.test_forward / test_backward.
# The two DeepseekV3 entries are skipped when the installed transformers version is below 5.0.0.
_CHUNKED_LM_HEAD_MODEL_IDS = [
    "trl-internal-testing/tiny-CohereForCausalLM",
    "trl-internal-testing/tiny-Cohere2ForCausalLM",
    pytest.param(
        "trl-internal-testing/tiny-DeepseekV3ForCausalLM",
        marks=pytest.mark.skipif(
            Version(transformers.__version__) < Version("5.0.0"),
            reason="DeepseekV3 SDPA attention is broken in transformers < 5.0.0",
        ),
    ),
    pytest.param(
        "trl-internal-testing/tiny-DeepseekV3ForCausalLM-0528",
        marks=pytest.mark.skipif(
            Version(transformers.__version__) < Version("5.0.0"),
            reason="DeepseekV3 SDPA attention is broken in transformers < 5.0.0",
        ),
    ),
    "trl-internal-testing/tiny-Gemma2ForCausalLM",
    "trl-internal-testing/tiny-GemmaForCausalLM",
    "trl-internal-testing/tiny-Glm4MoeForCausalLM",
    "trl-internal-testing/tiny-GptOssForCausalLM",
    "trl-internal-testing/tiny-LlamaForCausalLM-3.1",
    "trl-internal-testing/tiny-LlamaForCausalLM-3.2",
    "trl-internal-testing/tiny-LlamaForCausalLM-3",
    "trl-internal-testing/tiny-MistralForCausalLM-0.1",
    "trl-internal-testing/tiny-MistralForCausalLM-0.2",
    "trl-internal-testing/tiny-Phi3ForCausalLM-3",
    "trl-internal-testing/tiny-Phi3ForCausalLM-3.5",
    "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
    "trl-internal-testing/tiny-Qwen3ForCausalLM",
]


# NOTE(review): the two `test_dummy_model_*` tests below only create CPU tensors, yet this class-level
# decorator applies to them too (assuming it skips when no accelerator is present) — confirm this is intended.
@require_torch_accelerator
class TestPatchChunkedLMHead:
    """Tests for `patch_chunked_lm_head`.

    After patching, the tests call the model with `labels` and index the result with "log_probs" and "entropy".
    """

    B, S = 4, 16  # batch size, sequence length (including prompt + completion)
    H, V = 32, 128
    CHUNK_SIZE = 32

    def _build_model_and_inputs(self, temperature=1.0):
        """Build a patched `_FakeCausalLM` and return `(model, input_ids, attention_mask, completion_mask)`.

        `attention_mask` is all ones; `completion_mask` is 0 on the first half of each sequence and 1 on the rest.
        """
        torch.manual_seed(42)
        model = _FakeCausalLM(self.H, self.V)
        patch_chunked_lm_head(model, self.CHUNK_SIZE, temperature)

        input_ids = torch.randint(0, self.V, (self.B, self.S))
        attention_mask = torch.ones(self.B, self.S, dtype=torch.long)
        # First half of each sequence is prompt (0), second half is completion (1)
        completion_mask = torch.zeros(self.B, self.S, dtype=torch.float32)
        completion_mask[:, self.S // 2 :] = 1.0
        return model, input_ids, attention_mask, completion_mask

    @pytest.mark.parametrize("temperature", [1.0, 0.7])
    def test_dummy_model_chunked_forward_with_completion_mask(self, temperature):
        """Masked forward matches unmasked forward at completion positions and is zero at prompt positions."""
        model, input_ids, attention_mask, completion_mask = self._build_model_and_inputs(temperature)

        # Run WITHOUT completion_mask (baseline — computes all positions)
        out_full = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids)

        # Reset hidden state cache so both runs use the same hidden states
        model.model._hidden = None

        # Run WITH completion_mask
        out_masked = model(
            input_ids=input_ids, attention_mask=attention_mask, labels=input_ids, completion_mask=completion_mask
        )

        # shifted completion_mask (matching the shift in _chunked_forward)
        shifted_mask = completion_mask[:, 1:].bool()

        # At completion positions, values should match
        torch.testing.assert_close(
            out_masked["log_probs"][shifted_mask],
            out_full["log_probs"][shifted_mask],
            atol=1e-5,
            rtol=1e-5,
        )
        torch.testing.assert_close(
            out_masked["entropy"][shifted_mask],
            out_full["entropy"][shifted_mask],
            atol=1e-5,
            rtol=1e-5,
        )

        # At prompt positions, values should be zero
        prompt_mask = ~shifted_mask
        assert (out_masked["log_probs"][prompt_mask] == 0).all()
        assert (out_masked["entropy"][prompt_mask] == 0).all()

    @pytest.mark.parametrize("temperature", [1.0, 0.7])
    def test_dummy_model_chunked_forward_completion_mask_backward(self, temperature):
        """`lm_head.weight` gradients must agree whether the mask is applied inside the forward or afterwards."""
        model, input_ids, attention_mask, completion_mask = self._build_model_and_inputs(temperature)

        # Full forward + backward (mask applied after, as the trainer does)
        out_full = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids)
        shifted_mask = completion_mask[:, 1:]
        loss_full = (out_full["log_probs"] * shifted_mask).sum()
        loss_full.backward()
        grad_weight_full = model.lm_head.weight.grad.clone()

        # Drop the accumulated grad and the cached hidden states so the second pass starts fresh;
        # the fake body regenerates its hidden states from a fixed seed
        model.lm_head.weight.grad = None
        model.model._hidden = None

        # Masked forward + backward
        out_masked = model(
            input_ids=input_ids, attention_mask=attention_mask, labels=input_ids, completion_mask=completion_mask
        )
        loss_masked = (out_masked["log_probs"] * shifted_mask).sum()
        loss_masked.backward()
        grad_weight_masked = model.lm_head.weight.grad.clone()

        torch.testing.assert_close(grad_weight_masked, grad_weight_full, atol=1e-5, rtol=1e-5)

    @pytest.mark.parametrize("model_id", _CHUNKED_LM_HEAD_MODEL_IDS)
    @pytest.mark.parametrize("temperature", [1.0, 0.7])
    def test_forward(self, model_id, temperature):
        """On real tiny checkpoints (bfloat16), patched outputs must match logits-based references within 5e-3."""
        model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16).to(torch_device)
        if getattr(model.config, "final_logit_softcapping", None) is not None:
            pytest.skip("model uses final_logit_softcapping, not supported by chunked LM head")
        model.eval()

        B, S, chunk_size = 2, 8, 32
        torch.manual_seed(42)
        input_ids = torch.randint(0, model.config.vocab_size, (B, S), device=torch_device)
        labels = input_ids.clone()

        # Reference: standard forward → shifted logits → logprobs & entropy
        with torch.no_grad():
            ref_logits = model(input_ids=input_ids).logits[:, :-1, :].float() / temperature
        shifted_labels = labels[:, 1:]
        ref_log_p = F.log_softmax(ref_logits, dim=-1)
        ref_logprobs = ref_log_p.gather(-1, shifted_labels.unsqueeze(-1)).squeeze(-1)
        ref_p = ref_logits.softmax(dim=-1)
        ref_entropy = -(ref_p * ref_log_p).sum(dim=-1)

        # Chunked forward
        patch_chunked_lm_head(model, chunk_size, temperature)
        with torch.no_grad():
            out = model(input_ids=input_ids, labels=labels)

        torch.testing.assert_close(out["log_probs"], ref_logprobs, atol=5e-3, rtol=5e-3)
        torch.testing.assert_close(out["entropy"], ref_entropy, atol=5e-3, rtol=5e-3)

    @pytest.mark.parametrize("model_id", _CHUNKED_LM_HEAD_MODEL_IDS)
    @pytest.mark.parametrize("temperature", [1.0, 0.7])
    def test_backward(self, model_id, temperature):
        """`lm_head.weight` gradient of the patched copy must match the unpatched model within 5e-2."""
        # NOTE(review): unlike test_forward, neither model is switched to eval() here — confirm the tiny
        # checkpoints have no active dropout, otherwise the two backward passes are not comparable.
        model_ref = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16).to(torch_device)
        if getattr(model_ref.config, "final_logit_softcapping", None) is not None:
            pytest.skip("model uses final_logit_softcapping, not supported by chunked LM head")
        # Independent copy, so patching and gradient accumulation do not touch the reference model
        model_chunked = copy.deepcopy(model_ref)

        B, S, chunk_size = 2, 8, 32
        torch.manual_seed(42)
        input_ids = torch.randint(0, model_ref.config.vocab_size, (B, S), device=torch_device)
        labels = input_ids.clone()
        shifted_labels = labels[:, 1:]

        # Reference backward: standard logits → logprobs → backward
        ref_logits = model_ref(input_ids=input_ids).logits[:, :-1, :].float() / temperature
        ref_log_p = F.log_softmax(ref_logits, dim=-1)
        ref_logprobs = ref_log_p.gather(-1, shifted_labels.unsqueeze(-1)).squeeze(-1)
        ref_logprobs.sum().backward()
        ref_grad = model_ref.lm_head.weight.grad.clone()

        # Chunked backward
        patch_chunked_lm_head(model_chunked, chunk_size, temperature)
        out = model_chunked(input_ids=input_ids, labels=labels)
        out["log_probs"].sum().backward()
        chunked_grad = model_chunked.lm_head.weight.grad.clone()

        torch.testing.assert_close(chunked_grad, ref_grad, atol=5e-2, rtol=5e-2)