|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| import copy
|
| import textwrap
|
| from io import StringIO
|
| from unittest.mock import patch
|
|
|
| import pytest
|
| import torch
|
| import torch.nn as nn
|
| import torch.nn.functional as F
|
| import transformers
|
| from packaging.version import Version
|
| from transformers import AutoModelForCausalLM
|
| from transformers.testing_utils import torch_device
|
| from transformers.utils import is_peft_available
|
|
|
| from trl import ModelConfig
|
| from trl.trainer.utils import (
|
| RepeatSampler,
|
| _ChunkedLogProbFunction,
|
| entropy_from_logits,
|
| flush_left,
|
| generate_model_card,
|
| get_peft_config,
|
| hash_module,
|
| nanstd,
|
| pad,
|
| patch_chunked_lm_head,
|
| print_prompt_completions_sample,
|
| selective_log_softmax,
|
| shuffle_sequence_dict,
|
| split_pixel_values_by_grid,
|
| split_tensor_dict,
|
| unsplit_pixel_values_by_grid,
|
| use_adapter,
|
| )
|
|
|
| from .testing_utils import TrlTestCase, require_peft, require_rich, require_torch_accelerator
|
|
|
|
|
| if is_peft_available():
|
| from peft import AutoPeftModelForCausalLM, LoraConfig
|
|
|
|
|
@require_peft
class TestUseAdapter(TrlTestCase):
    """Tests for the `use_adapter` context manager on PEFT models."""

    def test_disables_on_none(self):
        # Selecting `None` should give the same logits as running with all adapters disabled.
        peft_model = AutoPeftModelForCausalLM.from_pretrained(
            "trl-internal-testing/tiny-PeftModel", adapter_name="my_adapter"
        )
        tokens = torch.tensor([[1, 2, 3], [4, 5, 6]])

        with peft_model.disable_adapter():
            reference_logits = peft_model(tokens).logits

        with use_adapter(peft_model, None):
            actual_logits = peft_model(tokens).logits

        assert torch.equal(actual_logits, reference_logits)

    def test_restores_previous_adapter(self):
        # Whatever the context selects (a named adapter, or none at all), the adapter that
        # was active before entering must be active again once the context exits.
        peft_model = AutoPeftModelForCausalLM.from_pretrained(
            "trl-internal-testing/tiny-PeftModel", adapter_name="my_adapter"
        )
        tokens = torch.tensor([[1, 2, 3], [4, 5, 6]])
        reference_logits = peft_model(tokens).logits

        for adapter_name in ("my_adapter", None):
            with use_adapter(peft_model, adapter_name):
                pass
            assert torch.equal(peft_model(tokens).logits, reference_logits)

    def test_with_multiple_adapters(self):
        # Selecting an adapter through the context manager must match activating it directly.
        peft_model = AutoPeftModelForCausalLM.from_pretrained(
            "trl-internal-testing/tiny-PeftModel", adapter_name="my_adapter_1"
        )
        peft_model.load_adapter("trl-internal-testing/tiny-PeftModel-2", "my_adapter_2")
        tokens = torch.tensor([[1, 2, 3], [4, 5, 6]])
        adapter_names = ("my_adapter_1", "my_adapter_2")

        # Reference logits: activate each adapter explicitly, in order.
        reference_logits = {}
        for name in adapter_names:
            peft_model.set_adapter(name)
            reference_logits[name] = peft_model(tokens).logits

        # Same adapters, now selected via `use_adapter`.
        for name in adapter_names:
            with use_adapter(peft_model, name):
                actual_logits = peft_model(tokens).logits
            assert torch.equal(actual_logits, reference_logits[name])
|
|
|
|
|
class TestPad(TrlTestCase):
    """Tests for `pad`, which stacks tensors of unequal size into a single padded tensor."""

    @staticmethod
    def _assert_pad(tensors, expected, **pad_kwargs):
        # Every case pads with 0; only the padding side and the optional multiple vary.
        padded = pad(tensors, padding_value=0, **pad_kwargs)
        assert torch.equal(padded, torch.tensor(expected))

    def test_pad_1_dim_left(self):
        sequences = (torch.tensor([1, 2, 3]), torch.tensor([4, 5]))
        self._assert_pad(sequences, [[1, 2, 3], [0, 4, 5]], padding_side="left")

    def test_pad_1_dim_right(self):
        sequences = (torch.tensor([1, 2, 3]), torch.tensor([4, 5]))
        self._assert_pad(sequences, [[1, 2, 3], [4, 5, 0]], padding_side="right")

    def test_pad_2_dim_left(self):
        sequences = (torch.tensor([[1, 2], [3, 4]]), torch.tensor([[5, 6]]))
        expected = [[[1, 2], [3, 4]], [[0, 0], [5, 6]]]
        self._assert_pad(sequences, expected, padding_side="left")

    def test_pad_2_dim_right(self):
        sequences = (torch.tensor([[1, 2], [3, 4]]), torch.tensor([[5, 6]]))
        expected = [[[1, 2], [3, 4]], [[5, 6], [0, 0]]]
        self._assert_pad(sequences, expected, padding_side="right")

    def test_pad_2_dim_right_multidim(self):
        # Padding is applied along every dimension, not only the first one.
        sequences = (torch.tensor([[1, 2], [3, 4]]), torch.tensor([[5]]))
        expected = [[[1, 2], [3, 4]], [[5, 0], [0, 0]]]
        self._assert_pad(sequences, expected, padding_side="right")

    def test_pad_to_multiple_of_1(self):
        # Longest sequence has length 3, which is rounded up to 4.
        sequences = (torch.tensor([1, 2, 3]), torch.tensor([4, 5]))
        self._assert_pad(sequences, [[1, 2, 3, 0], [4, 5, 0, 0]], padding_side="right", pad_to_multiple_of=4)

    def test_pad_to_multiple_of_2(self):
        # Longest sequence has length 5, which is rounded up to 8.
        sequences = (torch.tensor([1, 2, 3, 4, 5]), torch.tensor([6, 7, 8]))
        expected = [[1, 2, 3, 4, 5, 0, 0, 0], [6, 7, 8, 0, 0, 0, 0, 0]]
        self._assert_pad(sequences, expected, padding_side="right", pad_to_multiple_of=4)

    def test_pad_to_multiple_of_side_left(self):
        sequences = (torch.tensor([1, 2, 3, 4, 5]), torch.tensor([6, 7, 8]))
        expected = [[0, 0, 0, 1, 2, 3, 4, 5], [0, 0, 0, 0, 0, 6, 7, 8]]
        self._assert_pad(sequences, expected, padding_side="left", pad_to_multiple_of=4)

    def test_pad_to_multiple_of_no_extra_padding(self):
        # Length 4 is already a multiple of 4, so nothing is added.
        sequences = (torch.tensor([1, 2, 3, 4]), torch.tensor([5, 6, 7, 8]))
        self._assert_pad(sequences, [[1, 2, 3, 4], [5, 6, 7, 8]], padding_side="left", pad_to_multiple_of=4)
|
|
|
|
|
class TestHashModule(TrlTestCase):
    """Tests for `hash_module`, a content hash over a module's tensors."""

    def test_hash_module_deterministic_across_order(self):
        # The same named parameters registered in opposite order must hash identically.
        class ModAB(torch.nn.Module):
            def __init__(self, a: torch.Tensor, b: torch.Tensor):
                super().__init__()
                self.a = torch.nn.Parameter(a)
                self.b = torch.nn.Parameter(b)

        class ModBA(torch.nn.Module):
            def __init__(self, a: torch.Tensor, b: torch.Tensor):
                super().__init__()
                self.b = torch.nn.Parameter(b)
                self.a = torch.nn.Parameter(a)

        weight_a = torch.tensor([[1.0, 2.0]])
        weight_b = torch.tensor([3.0])
        assert hash_module(ModAB(weight_a, weight_b)) == hash_module(ModBA(weight_a, weight_b))

    def test_hash_module_changes_with_value(self):
        class Mod(torch.nn.Module):
            def __init__(self, value: float):
                super().__init__()
                self.weight = torch.nn.Parameter(torch.tensor([value, 2.0]))

        # Two different parameter values must yield two distinct hashes.
        distinct_hashes = {hash_module(Mod(v)) for v in (1.0, 1.5)}
        assert len(distinct_hashes) == 2

    def test_hash_module_includes_dtype(self):
        class Mod(torch.nn.Module):
            def __init__(self, dtype: torch.dtype):
                super().__init__()
                self.weight = torch.nn.Parameter(torch.tensor([1.0, 2.0], dtype=dtype))

        # Same values stored in different dtypes must yield two distinct hashes.
        distinct_hashes = {hash_module(Mod(dt)) for dt in (torch.float32, torch.float16)}
        assert len(distinct_hashes) == 2

    def test_hash_module_tiny_model_twice(self):
        # Two independent loads of the same checkpoint must hash identically.
        model_id = "trl-internal-testing/tiny-GptOssForCausalLM"
        first, second = (AutoModelForCausalLM.from_pretrained(model_id) for _ in range(2))
        assert hash_module(first) == hash_module(second)

    def test_hash_module_tiny_model_change_layer(self):
        # An in-place update to a single weight must change the hash.
        model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-GptOssForCausalLM")
        hash_before = hash_module(model)
        with torch.no_grad():
            model.lm_head.weight.add_(0.01)
        assert hash_module(model) != hash_before
|
|
|
|
|
@require_peft
class TestGetPEFTConfig(TrlTestCase):
    """Tests for `get_peft_config`, which builds a PEFT config from a `ModelConfig`."""

    def test_create_peft_config_use_peft_false(self):
        """Test that when use_peft is False, the function returns None."""
        assert get_peft_config(ModelConfig(use_peft=False)) is None

    def test_create_peft_config_use_peft_true(self):
        """Test that when use_peft is True, the function returns a LoraConfig object."""
        peft_kwargs = {
            "lora_r": 8,
            "lora_alpha": 16,
            "lora_dropout": 0.1,
            "lora_task_type": "SEQ_CLS",
            "use_rslora": True,
            "lora_target_modules": ["up_proj", "down_proj"],
            "lora_modules_to_save": ["up_proj"],
        }
        peft_config = get_peft_config(ModelConfig(use_peft=True, **peft_kwargs))
        assert isinstance(peft_config, LoraConfig)

        # `ModelConfig` field name -> `LoraConfig` attribute name, for the fields whose names differ.
        # Fields not listed here (lora_alpha, lora_dropout, use_rslora) keep the same name.
        renamed_fields = {
            "lora_r": "r",
            "lora_task_type": "task_type",
            "lora_target_modules": "target_modules",
            "lora_modules_to_save": "modules_to_save",
        }
        for field_name, value in peft_kwargs.items():
            # target_modules is compared as a set (order-insensitive).
            expected = set(value) if field_name == "lora_target_modules" else value
            assert getattr(peft_config, renamed_fields.get(field_name, field_name)) == expected
|
|
|
|
|
class TestNanStd(TrlTestCase):
    """Tests for `nanstd`, a standard deviation that skips NaN entries.

    The expected values below correspond to the sample standard deviation (n - 1 denominator).
    """

    def test_nanstd_ignores_nans(self):
        # NaN is ignored, so this is the sample std of [1, 2, 3], which is exactly 1.
        x = torch.tensor([1.0, 2.0, 3.0, float("nan")])
        result = nanstd(x)
        torch.testing.assert_close(result, torch.tensor(1.0))

    def test_nanstd_dim_and_keepdim(self):
        x = torch.tensor([[1.0, float("nan")], [3.0, 5.0]])
        result = nanstd(x, dim=1, keepdim=True)
        # keepdim=True must keep the reduced axis with size 1; indexing alone would not catch a wrong shape.
        assert result.shape == (2, 1)
        # Row 0 has a single non-NaN value, so the sample std is undefined (NaN).
        assert torch.isnan(result[0, 0])
        # Row 1: sample std of [3, 5] is sqrt(2).
        torch.testing.assert_close(result[1, 0], torch.tensor(1.4142135), rtol=1e-5, atol=1e-6)

    def test_nanstd_all_nan(self):
        # No valid entries at all -> NaN.
        x = torch.tensor([float("nan"), float("nan")])
        result = nanstd(x)
        assert torch.isnan(result)
|
|
|
|
|
class TestGenerateModelCard(TrlTestCase):
    """Tests that `generate_model_card` renders the provided metadata into the card text."""

    def test_full(self):
        card_text = str(
            generate_model_card(
                base_model="username/my_base_model",
                model_name="my_model",
                hub_model_id="username/my_hub_model",
                dataset_name="username/my_dataset",
                tags=["trl", "trainer-tag"],
                wandb_url="https://wandb.ai/username/project_id/runs/abcd1234",
                trackio_url="https://huggingface.co/spaces/username/space_id",
                comet_url="https://www.comet.com/username/project_id/experiment_id",
                trainer_name="My Trainer",
                trainer_citation="@article{my_trainer, ...}",
                paper_title="My Paper",
                paper_id="1234.56789",
            )
        )

        expected_snippets = [
            "[username/my_base_model](https://huggingface.co/username/my_base_model)",
            "my_model",
            'pipeline("text-generation", model="username/my_hub_model", device="cuda")',
            "datasets: username/my_dataset",
            "](https://wandb.ai/username/project_id/runs/abcd1234)",
            "](https://huggingface.co/spaces/username/space_id)",
            # NOTE(review): unlike the two URL checks above, this one omits the closing ")" — confirm intentional.
            "](https://www.comet.com/username/project_id/experiment_id",
            "My Trainer",
            "```bibtex\n@article{my_trainer, ...}\n```",
            "[My Paper](https://huggingface.co/papers/1234.56789)",
        ]
        for snippet in expected_snippets:
            assert snippet in card_text

    def test_val_none(self):
        # Only the mandatory fields are set; optional ones are None/empty.
        card_text = str(
            generate_model_card(
                base_model=None,
                model_name="my_model",
                hub_model_id="username/my_hub_model",
                dataset_name=None,
                tags=[],
                wandb_url=None,
                trackio_url=None,
                comet_url=None,
                trainer_name="My Trainer",
                trainer_citation=None,
                paper_title=None,
                paper_id=None,
            )
        )

        for snippet in (
            "my_model",
            'pipeline("text-generation", model="username/my_hub_model", device="cuda")',
            "My Trainer",
        ):
            assert snippet in card_text
|
|
|
|
|
class TestFlushLeft(TrlTestCase):
    """Tests for `flush_left`.

    Per the expected values below, each row's masked-in entries are shifted to the left, and trailing columns
    that are empty in every row are dropped. Companion tensors are shifted the same way as the mask.
    """

    def test_basic_case(self):
        attention = torch.tensor([[0, 0, 1, 1, 1], [0, 1, 1, 0, 0]])
        first = torch.tensor([[0, 0, 2, 3, 4], [0, 5, 6, 0, 0]])
        second = torch.tensor([[0, 0, 7, 8, 9], [0, 10, 11, 0, 0]])

        shifted_mask, shifted_first, shifted_second = flush_left(attention, first, second)

        assert torch.equal(shifted_mask, torch.tensor([[1, 1, 1], [1, 1, 0]]))
        assert torch.equal(shifted_first, torch.tensor([[2, 3, 4], [5, 6, 0]]))
        assert torch.equal(shifted_second, torch.tensor([[7, 8, 9], [10, 11, 0]]))

    def test_single_row(self):
        attention = torch.tensor([[0, 0, 1, 1]])
        values = torch.tensor([[0, 0, 2, 3]])

        shifted_mask, shifted_values = flush_left(attention, values)

        assert torch.equal(shifted_mask, torch.tensor([[1, 1]]))
        assert torch.equal(shifted_values, torch.tensor([[2, 3]]))

    def test_no_shift_needed(self):
        # Already left-aligned: only the all-zero trailing columns are trimmed.
        attention = torch.tensor([[1, 1, 0, 0], [1, 0, 0, 0]])
        values = torch.tensor([[5, 6, 0, 0], [7, 0, 0, 0]])

        shifted_mask, shifted_values = flush_left(attention, values)

        assert torch.equal(shifted_mask, torch.tensor([[1, 1], [1, 0]]))
        assert torch.equal(shifted_values, torch.tensor([[5, 6], [7, 0]]))

    def test_no_tensors(self):
        # With only a mask, the shifted mask is returned on its own (not in a tuple).
        attention = torch.tensor([[0, 0, 1, 1, 1], [0, 1, 1, 0, 0]])
        assert torch.equal(flush_left(attention), torch.tensor([[1, 1, 1], [1, 1, 0]]))
|
|
|
|
|
class TestRepeatRandomSampler(TrlTestCase):
    """Tests for `RepeatSampler`.

    As exercised below, indices are grouped into batches of `batch_size` (an incomplete trailing batch is dropped).
    Each batch is emitted `repeat_count` times in a row, and every index inside a batch is repeated
    `mini_repeat_count` times consecutively.
    """

    def test_sampler(self):
        dataset = ["a", "b", "c", "d", "e", "f", "g"]
        sampler = RepeatSampler(dataset, mini_repeat_count=2)
        # Expected shape of output: [4, 4, 3, 3, 0, 0, 1, 1, 2, 2, 6, 6, 5, 5]
        sampled = list(sampler)

        # Length is doubled, and __len__ agrees with the iterator.
        assert len(sampled) == 2 * len(dataset)
        assert len(sampler) == len(sampled)
        # Every index appears exactly twice (stronger than a plain set check).
        assert sorted(sampled) == sorted(2 * list(range(len(dataset))))
        # Each index's two copies are adjacent.
        assert all(sampled[i] == sampled[i + 1] for i in range(0, len(sampled), 2))

    def test_sampler_no_shuffle(self):
        dataset = ["a", "b", "c", "d", "e", "f", "g"]
        sampler = RepeatSampler(dataset, mini_repeat_count=2, shuffle=False)
        sampled = list(sampler)
        expected = [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6]
        assert sampled == expected
        assert len(sampler) == len(sampled)

    def test_sampler_no_repeat(self):
        dataset = ["a", "b", "c", "d", "e", "f", "g"]
        sampler = RepeatSampler(dataset, mini_repeat_count=1)
        sampled = list(sampler)

        assert len(sampled) == len(dataset)
        assert len(sampler) == len(sampled)
        # A permutation of all indices: each appears exactly once.
        assert sorted(sampled) == list(range(len(dataset)))

    def test_sampler_with_batch_size(self):
        dataset = ["a", "b", "c", "d", "e", "f", "g", "h"]
        sampler = RepeatSampler(dataset, mini_repeat_count=1, batch_size=2, repeat_count=2)
        # Expected shape of output: [4, 5, 4, 5, 0, 1, 0, 1, 2, 3, 2, 3, 6, 7, 6, 7]
        sampled = list(sampler)

        assert len(sampled) == 2 * len(dataset)
        assert len(sampler) == len(sampled)
        # Every index appears exactly twice.
        assert sorted(sampled) == sorted(2 * list(range(len(dataset))))
        # Each full batch of 2 is immediately repeated (compare whole batches, not just their first element).
        assert all(sampled[i : i + 2] == sampled[i + 2 : i + 4] for i in range(0, len(sampled), 4))

    def test_sampler_with_batch_size_and_drop(self):
        dataset = ["a", "b", "c", "d", "e", "f", "g"]
        sampler = RepeatSampler(dataset, mini_repeat_count=1, batch_size=2, repeat_count=2)
        # Expected shape of output: [4, 5, 4, 5, 0, 1, 0, 1, 2, 3, 2, 3]
        sampled = list(sampler)

        # Length is doubled, minus the one index dropped with the incomplete last batch.
        assert len(sampled) == 2 * (len(dataset) - 1)
        assert len(sampler) == len(sampled)
        # Exactly one index was dropped; all emitted indices come from the dataset.
        assert set(sampled).issubset(set(range(len(dataset))))
        assert len(set(sampled)) == len(dataset) - 1
        # Each full batch of 2 is immediately repeated.
        assert all(sampled[i : i + 2] == sampled[i + 2 : i + 4] for i in range(0, len(sampled), 4))

    def test_sampler_with_mini_repeat_count_and_batch_size_1(self):
        dataset = ["a", "b", "c", "d", "e", "f", "g"]
        sampler = RepeatSampler(dataset, mini_repeat_count=2, batch_size=3, repeat_count=2)
        # Expected shape of output:
        # [4, 4, 3, 3, 0, 0, 4, 4, 3, 3, 0, 0,
        #  1, 1, 2, 2, 6, 6, 1, 1, 2, 2, 6, 6]
        sampled = list(sampler)

        # 2 full batches of 3, each index mini-repeated twice, each batch repeated twice.
        assert len(sampled) == 4 * (len(dataset) - 1)
        assert len(sampler) == len(sampled)
        assert set(sampled).issubset(set(range(len(dataset))))
        assert len(set(sampled)) == len(dataset) - 1
        # Mini-repeats are adjacent.
        assert all(sampled[i] == sampled[i + 1] for i in range(0, len(sampled), 2))
        # Each 6-element block (batch of 3 x mini-repeat 2) is repeated once.
        assert sampled[0:6] == sampled[6:12]
        assert sampled[12:18] == sampled[18:24]

    def test_sampler_with_mini_repeat_count_and_batch_size_2(self):
        dataset = ["a", "b", "c", "d", "e", "f", "g"]
        sampler = RepeatSampler(dataset, mini_repeat_count=3, batch_size=2, repeat_count=2)
        # Expected shape of output:
        # [4, 4, 4, 3, 3, 3, 4, 4, 4, 3, 3, 3,
        #  0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1,
        #  2, 2, 2, 6, 6, 6, 2, 2, 2, 6, 6, 6]
        sampled = list(sampler)

        # 3 full batches of 2, each index mini-repeated 3 times, each batch repeated twice.
        assert len(sampled) == 6 * (len(dataset) - 1)
        assert len(sampler) == len(sampled)
        assert set(sampled).issubset(set(range(len(dataset))))
        assert len(set(sampled)) == len(dataset) - 1
        # Mini-repeats are adjacent.
        assert all(sampled[i] == sampled[i + 1] == sampled[i + 2] for i in range(0, len(sampled), 3))
        # Each 6-element block (batch of 2 x mini-repeat 3) is repeated once.
        assert sampled[0:6] == sampled[6:12]
        assert sampled[12:18] == sampled[18:24]
        assert sampled[24:30] == sampled[30:36]

    def test_sampler_with_mini_repeat_count_and_batch_size_3(self):
        dataset = ["a", "b", "c", "d", "e", "f", "g"]
        sampler = RepeatSampler(dataset, mini_repeat_count=2, batch_size=2, repeat_count=3)
        # Expected shape of output:
        # [4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3,
        #  0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1,
        #  2, 2, 6, 6, 2, 2, 6, 6, 2, 2, 6, 6]
        sampled = list(sampler)

        # 3 full batches of 2, each index mini-repeated twice, each batch repeated 3 times.
        assert len(sampled) == 6 * (len(dataset) - 1)
        assert len(sampler) == len(sampled)
        assert set(sampled).issubset(set(range(len(dataset))))
        assert len(set(sampled)) == len(dataset) - 1
        # Mini-repeats are adjacent.
        assert all(sampled[i] == sampled[i + 1] for i in range(0, len(sampled), 2))
        # Each 4-element block (batch of 2 x mini-repeat 2) is repeated 3 times.
        assert sampled[0:4] == sampled[4:8] == sampled[8:12]
        assert sampled[12:16] == sampled[16:20] == sampled[20:24]
        assert sampled[24:28] == sampled[28:32] == sampled[32:36]
|
|
|
|
|
class TestEntropyFromLogits(TrlTestCase):
    """Compares `entropy_from_logits` against a direct softmax-based reference."""

    # NOTE: despite the "2_dims" name, this test covers inputs of rank 1 through 4.
    @pytest.mark.parametrize("shape", [(768,), (32, 768), (8, 16, 768), (2, 4, 8, 768)])
    @pytest.mark.parametrize("chunk_size", [1, 16])
    @pytest.mark.parametrize("dtype", [torch.float64, torch.float32, torch.float16, torch.bfloat16])
    def test_entropy_from_logits_2_dims(self, dtype, chunk_size, shape):
        logits = torch.randn(*shape, dtype=dtype)

        if dtype in (torch.float16, torch.bfloat16):
            # Half-precision reference goes through log_softmax.
            log_probs = logits.log_softmax(dim=-1)
            reference = -(log_probs.exp() * log_probs).sum(-1)
        else:
            probs = logits.softmax(-1)
            reference = -(probs * probs.log()).sum(dim=-1)

        torch.testing.assert_close(
            entropy_from_logits(logits, chunk_size=chunk_size), reference, rtol=1e-5, atol=1e-5
        )
|
|
|
|
|
@require_rich
class TestPrintPromptCompletionsSample(TrlTestCase):
    """Snapshot tests for `print_prompt_completions_sample`.

    Each test patches `sys.stdout` with a `StringIO`, calls the printer, and compares the captured text
    against an expected rich table panel.
    """

    # NOTE(review): in this copy of the file, the padding inside the expected tables appears collapsed to single
    # spaces. The header border segments (e.g. the "Prompt" column) are visibly wider than the single-space-padded
    # cell text they enclose. These exact-equality assertions are unlikely to match rich's real output until the
    # tables are regenerated from an actual run — TODO confirm.

    @patch("sys.stdout", new_callable=StringIO)
    def test_print_output(self, mock_stdout):
        # Plain-string prompts/completions with two reward columns plus an advantage column.
        prompts = ["The sky is", "The sun is"]
        completions = [" blue.", " in the sky."]
        rewards = {"Correctness": [0.123, 0.456], "Format": [0.789, 0.101]}
        advantages = [0.987, 0.654]
        step = 42

        print_prompt_completions_sample(prompts, completions, rewards, advantages, step)

        output = mock_stdout.getvalue()

        # Rewards and advantages are shown rounded to two decimals; the panel title carries the step.
        expected_output = textwrap.dedent("""\
        ╭──────────────────────────── Step 42 ─────────────────────────────╮
        │ ┏━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━┳━━━━━━━━━━━┓ │
        │ ┃ Prompt ┃ Completion ┃ Correctness ┃ Format ┃ Advantage ┃ │
        │ ┡━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━╇━━━━━━━━━━━┩ │
        │ │ The sky is │ blue. │ 0.12 │ 0.79 │ 0.99 │ │
        │ ├────────────┼──────────────┼─────────────┼────────┼───────────┤ │
        │ │ The sun is │ in the sky. │ 0.46 │ 0.10 │ 0.65 │ │
        │ └────────────┴──────────────┴─────────────┴────────┴───────────┘ │
        ╰──────────────────────────────────────────────────────────────────╯
        """)

        assert output == expected_output

    @patch("sys.stdout", new_callable=StringIO)
    def test_extra_columns(self, mock_stdout):
        # `extra` adds one column per key after the Advantage column.
        prompts = ["The sky is", "The sun is"]
        completions = [" blue.", " in the sky."]
        rewards = {"Correctness": [0.123, 0.456], "Format": [0.789, 0.101]}
        advantages = [0.987, 0.654]
        extra = {"source": ["dataset_A", "dataset_B"]}
        step = 42

        print_prompt_completions_sample(prompts, completions, rewards, advantages, step, extra=extra)

        output = mock_stdout.getvalue()

        expected_output = textwrap.dedent("""\
        ╭────────────────────────────────── Step 42 ───────────────────────────────────╮
        │ ┏━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━┓ │
        │ ┃ Prompt ┃ Completion ┃ Correctness ┃ Format ┃ Advantage ┃ source ┃ │
        │ ┡━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━┩ │
        │ │ The sky is │ blue. │ 0.12 │ 0.79 │ 0.99 │ dataset_A │ │
        │ ├────────────┼──────────────┼─────────────┼────────┼───────────┼───────────┤ │
        │ │ The sun is │ in the sky. │ 0.46 │ 0.10 │ 0.65 │ dataset_B │ │
        │ └────────────┴──────────────┴─────────────┴────────┴───────────┴───────────┘ │
        ╰──────────────────────────────────────────────────────────────────────────────╯
        """)

        assert output == expected_output

    @patch("sys.stdout", new_callable=StringIO)
    def test_num_samples(self, mock_stdout):
        # With num_samples=1 only one of the two rows is printed; either row is accepted below.
        prompts = ["A", "B"]
        completions = ["1", "2"]
        rewards = {"Score": [0.1, 0.2]}
        advantages = [0.3, 0.4]
        step = 10

        print_prompt_completions_sample(prompts, completions, rewards, advantages, step, num_samples=1)
        output = mock_stdout.getvalue()

        possible_outputs = [
            textwrap.dedent("""\
            ╭────────────────── Step 10 ──────────────────╮
            │ ┏━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━┳━━━━━━━━━━━┓ │
            │ ┃ Prompt ┃ Completion ┃ Score ┃ Advantage ┃ │
            │ ┡━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━╇━━━━━━━━━━━┩ │
            │ │ A │ 1 │ 0.10 │ 0.30 │ │
            │ └────────┴────────────┴───────┴───────────┘ │
            ╰─────────────────────────────────────────────╯
            """),
            textwrap.dedent("""\
            ╭────────────────── Step 10 ──────────────────╮
            │ ┏━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━┳━━━━━━━━━━━┓ │
            │ ┃ Prompt ┃ Completion ┃ Score ┃ Advantage ┃ │
            │ ┡━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━╇━━━━━━━━━━━┩ │
            │ │ B │ 2 │ 0.20 │ 0.40 │ │
            │ └────────┴────────────┴───────┴───────────┘ │
            ╰─────────────────────────────────────────────╯
            """),
        ]
        assert output in possible_outputs

    @patch("sys.stdout", new_callable=StringIO)
    def test_print_messages(self, mock_stdout):
        # Conversational prompts/completions: each message is shown as an uppercased role line followed by its
        # content, with a blank line between consecutive messages.
        prompts = [
            [
                {"role": "system", "content": "You are an helpful assistant."},
                {"role": "user", "content": "What color is the sky?"},
            ],
            [
                {"role": "system", "content": "You are an helpful assistant."},
                {"role": "user", "content": "Where is the sun?"},
            ],
        ]
        completions = [
            [{"role": "assistant", "content": "It is blue."}],
            [{"role": "assistant", "content": "In the sky."}],
        ]
        rewards = {"Correctness": [0.123, 0.456], "Format": [0.789, 0.101]}
        advantages = [0.987, 0.654]
        step = 42

        print_prompt_completions_sample(prompts, completions, rewards, advantages, step)

        output = mock_stdout.getvalue()

        expected_output = textwrap.dedent("""\
        ╭────────────────────────────────── Step 42 ───────────────────────────────────╮
        │ ┏━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━┳━━━━━━━━━━━┓ │
        │ ┃ Prompt ┃ Completion ┃ Correctness ┃ Format ┃ Advantage ┃ │
        │ ┡━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━╇━━━━━━━━━━━┩ │
        │ │ SYSTEM │ ASSISTANT │ 0.12 │ 0.79 │ 0.99 │ │
        │ │ You are an helpful │ It is blue. │ │ │ │ │
        │ │ assistant. │ │ │ │ │ │
        │ │ │ │ │ │ │ │
        │ │ USER │ │ │ │ │ │
        │ │ What color is the sky? │ │ │ │ │ │
        │ ├─────────────────────────┼─────────────┼─────────────┼────────┼───────────┤ │
        │ │ SYSTEM │ ASSISTANT │ 0.46 │ 0.10 │ 0.65 │ │
        │ │ You are an helpful │ In the sky. │ │ │ │ │
        │ │ assistant. │ │ │ │ │ │
        │ │ │ │ │ │ │ │
        │ │ USER │ │ │ │ │ │
        │ │ Where is the sun? │ │ │ │ │ │
        │ └─────────────────────────┴─────────────┴─────────────┴────────┴───────────┘ │
        ╰──────────────────────────────────────────────────────────────────────────────╯
        """)

        assert output == expected_output

    @patch("sys.stdout", new_callable=StringIO)
    def test_print_messages_with_tools(self, mock_stdout):
        # Tool messages are rendered as `name(args)`; long calls are truncated with an ellipsis in the
        # expected output.
        prompts = [
            [{"role": "user", "content": "What is the temperature in Paris?"}],
            [{"role": "user", "content": "What is the weather in London?"}],
        ]
        completions = [
            [{"role": "tool", "name": "get_temperature", "args": {"location": "Paris"}}],
            [{"role": "tool", "name": "get_weather", "args": {"location": "London"}}],
        ]
        rewards = {"Correctness": [0.123, 0.456], "Format": [0.789, 0.101]}
        advantages = [0.987, 0.654]
        step = 42

        print_prompt_completions_sample(prompts, completions, rewards, advantages, step)

        output = mock_stdout.getvalue()

        expected_output = textwrap.dedent("""\
        ╭────────────────────────────────── Step 42 ───────────────────────────────────╮
        │ ┏━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━┳━━━━━━━━━━━┓ │
        │ ┃ Prompt ┃ Completion ┃ Correctness ┃ Format ┃ Advantage ┃ │
        │ ┡━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━╇━━━━━━━━━━━┩ │
        │ │ USER │ TOOL │ 0.12 │ 0.79 │ 0.99 │ │
        │ │ What is the │ get_temperature(… │ │ │ │ │
        │ │ temperature in │ 'Paris'}) │ │ │ │ │
        │ │ Paris? │ │ │ │ │ │
        │ ├───────────────────┼───────────────────┼─────────────┼────────┼───────────┤ │
        │ │ USER │ TOOL │ 0.46 │ 0.10 │ 0.65 │ │
        │ │ What is the │ get_weather({'lo… │ │ │ │ │
        │ │ weather in │ 'London'}) │ │ │ │ │
        │ │ London? │ │ │ │ │ │
        │ └───────────────────┴───────────────────┴─────────────┴────────┴───────────┘ │
        ╰──────────────────────────────────────────────────────────────────────────────╯
        """)

        assert output == expected_output

    @patch("sys.stdout", new_callable=StringIO)
    def test_print_messages_with_reasoning_content(self, mock_stdout):
        # `reasoning_content` is printed on its own line before the assistant's `content`.
        prompts = [[{"role": "user", "content": "What color is the sky?"}]]
        completions = [[{"role": "assistant", "reasoning_content": "I think it is blue.", "content": "It is blue."}]]
        rewards = {"Score": [0.5]}
        advantages = [0.9]
        step = 1

        print_prompt_completions_sample(prompts, completions, rewards, advantages, step)

        output = mock_stdout.getvalue()

        expected_output = textwrap.dedent("""\
        ╭─────────────────────────────── Step 1 ───────────────────────────────╮
        │ ┏━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━┳━━━━━━━━━━━┓ │
        │ ┃ Prompt ┃ Completion ┃ Score ┃ Advantage ┃ │
        │ ┡━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━╇━━━━━━━━━━━┩ │
        │ │ USER │ ASSISTANT │ 0.50 │ 0.90 │ │
        │ │ What color is the sky? │ I think it is blue. │ │ │ │
        │ │ │ It is blue. │ │ │ │
        │ └────────────────────────┴─────────────────────┴───────┴───────────┘ │
        ╰──────────────────────────────────────────────────────────────────────╯
        """)

        assert output == expected_output

    @patch("sys.stdout", new_callable=StringIO)
    def test_print_messages_with_thinking(self, mock_stdout):
        # The `thinking` key is expected to render exactly like `reasoning_content` above.
        prompts = [[{"role": "user", "content": "What color is the sky?"}]]
        completions = [[{"role": "assistant", "thinking": "I think it is blue.", "content": "It is blue."}]]
        rewards = {"Score": [0.5]}
        advantages = [0.9]
        step = 1

        print_prompt_completions_sample(prompts, completions, rewards, advantages, step)

        output = mock_stdout.getvalue()

        expected_output = textwrap.dedent("""\
        ╭─────────────────────────────── Step 1 ───────────────────────────────╮
        │ ┏━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━┳━━━━━━━━━━━┓ │
        │ ┃ Prompt ┃ Completion ┃ Score ┃ Advantage ┃ │
        │ ┡━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━╇━━━━━━━━━━━┩ │
        │ │ USER │ ASSISTANT │ 0.50 │ 0.90 │ │
        │ │ What color is the sky? │ I think it is blue. │ │ │ │
        │ │ │ It is blue. │ │ │ │
        │ └────────────────────────┴─────────────────────┴───────┴───────────┘ │
        ╰──────────────────────────────────────────────────────────────────────╯
        """)

        assert output == expected_output
|
|
|
|
|
class TestSelectiveLogSoftmax(TrlTestCase):
    """Compares `selective_log_softmax` against gathering from a full `log_softmax`."""

    @staticmethod
    def _assert_matches_reference(actual, expected, dtype):
        # The tests require bit-exact agreement for half-precision dtypes, and a tight tolerance otherwise.
        if dtype in [torch.float16, torch.bfloat16]:
            assert torch.equal(actual, expected)
        else:
            torch.testing.assert_close(actual, expected, rtol=1e-5, atol=1e-5)

    @pytest.mark.parametrize("dtype", [torch.float64, torch.float32, torch.float16, torch.bfloat16])
    def test_selective_log_softmax(self, dtype):
        """Test selective_log_softmax with logits of different dtypes"""
        vocab_size, batch_size, seq_len = 1024, 4, 32

        # Draw the ids first, then the logits, so the RNG stream order is fixed.
        input_ids = torch.randint(low=0, high=vocab_size, size=(batch_size, seq_len))
        logits = torch.randn(batch_size, seq_len, vocab_size, dtype=dtype)

        full_log_probs = logits.log_softmax(-1)
        expected_output = full_log_probs.gather(dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1)
        actual_output = selective_log_softmax(logits, input_ids)

        self._assert_matches_reference(actual_output, expected_output, dtype)

    @pytest.mark.parametrize("dtype", [torch.float64, torch.float32, torch.float16, torch.bfloat16])
    @pytest.mark.parametrize("k", [1, 8])
    def test_selective_log_softmax_multi_index(self, dtype, k):
        """Test selective_log_softmax with logits of different dtypes and index widths"""
        vocab_size, batch_size, seq_len = 1024, 4, 32

        # k indices per position: the output keeps the trailing index dimension.
        index = torch.randint(low=0, high=vocab_size, size=(batch_size, seq_len, k))
        logits = torch.randn(batch_size, seq_len, vocab_size, dtype=dtype)

        expected_output = logits.log_softmax(-1).gather(dim=-1, index=index)
        actual_output = selective_log_softmax(logits, index)

        assert actual_output.shape == (batch_size, seq_len, k)
        self._assert_matches_reference(actual_output, expected_output, dtype)
|
|
|
|
|
class TestShuffleSequenceDict(TrlTestCase):
    """Tests for `shuffle_sequence_dict`.

    As checked below, one shared permutation is applied to every entry, so row i of each entry stays paired.
    `None` entries pass through unchanged.
    """

    def test_shuffle_preserves_shape(self):
        x = torch.arange(6).reshape(3, 2)
        y = torch.arange(3).reshape(3, 1)

        shuffled = shuffle_sequence_dict({"x": x.clone(), "y": y.clone()})

        assert shuffled["x"].shape == x.shape
        assert shuffled["y"].shape == y.shape

    def test_shuffle_consistent_across_tensors(self):
        # Each x row is paired with a known y value; the pairing must survive the shuffle.
        x = torch.tensor([[10, 11], [20, 21], [30, 31]])
        y = torch.tensor([[1], [2], [3]])
        expected_pairs = {(10, 11): 1, (20, 21): 2, (30, 31): 3}

        shuffled = shuffle_sequence_dict({"x": x.clone(), "y": y.clone()})

        for i in range(3):
            x_row = tuple(shuffled["x"][i].tolist())
            y_val = shuffled["y"][i].item()
            if x_row not in expected_pairs:
                pytest.fail("Unexpected x row in shuffled output.")
            assert y_val == expected_pairs[x_row]

    def test_none_tensor_remains_none(self):
        x = torch.arange(6).reshape(3, 2)

        shuffled = shuffle_sequence_dict({"x": x.clone(), "y": None})

        assert shuffled["y"] is None
        assert shuffled["x"].shape == x.shape

    def test_shuffle_with_list(self):
        # Plain Python lists are permuted in lockstep with tensors.
        x = torch.tensor([[10, 11], [20, 21], [30, 31]])
        expected_pairs = {(10, 11): "a", (20, 21): "b", (30, 31): "c"}

        shuffled = shuffle_sequence_dict({"x": x.clone(), "y": ["a", "b", "c"]})

        for i in range(3):
            x_row = tuple(shuffled["x"][i].tolist())
            y_val = shuffled["y"][i]
            if x_row not in expected_pairs:
                pytest.fail("Unexpected x row in shuffled output.")
            assert y_val == expected_pairs[x_row]
|
|
|
|
|
class TestSplitTensorDict(TrlTestCase):
    """Tests for `split_tensor_dict`, which splits every entry of a dict into N chunks along dim 0."""

    def test_split_equal_chunks(self):
        x = torch.arange(12).reshape(6, 2)
        y = torch.arange(6).reshape(6, 1)

        chunks = split_tensor_dict({"x": x, "y": y}, 3)

        assert len(chunks) == 3
        for chunk, x_part, y_part in zip(chunks, x.chunk(3, dim=0), y.chunk(3, dim=0)):
            assert torch.equal(chunk["x"], x_part)
            assert torch.equal(chunk["y"], y_part)

    def test_with_none_tensor(self):
        # None entries are carried into every chunk as None.
        x = torch.arange(12).reshape(6, 2)

        chunks = split_tensor_dict({"x": x, "y": None}, 2)

        assert len(chunks) == 2
        for chunk, x_part in zip(chunks, x.chunk(2, dim=0)):
            assert torch.equal(chunk["x"], x_part)
            assert chunk["y"] is None

    def test_with_scalar(self):
        # A 0-d tensor is not split; every chunk receives the same scalar.
        x = torch.arange(12).reshape(6, 2)

        chunks = split_tensor_dict({"x": x, "y": torch.tensor(1)}, 2)

        assert len(chunks) == 2
        for chunk, x_part in zip(chunks, x.chunk(2, dim=0)):
            assert torch.equal(chunk["x"], x_part)
            assert torch.equal(chunk["y"], torch.tensor(1))
|
|
|
|
|
class TestSplitPixelValuesByGrid(TrlTestCase):
    """Tests for ``split_pixel_values_by_grid``, which splits flat image tensors into per-sample lists."""

    @staticmethod
    def _assert_tensor_list(actual, expected):
        """Check that ``actual`` is a list holding exactly the tensors in ``expected``, in order."""
        assert isinstance(actual, list)
        assert len(actual) == len(expected)
        for got, want in zip(actual, expected):
            assert torch.equal(got, want)

    def test_split_correctly_0(self):
        # Two single-image samples; each grid row 1*2*2 accounts for 4 rows of pixel_values.
        batch = {
            "image_grid_thw": torch.tensor([[1, 2, 2], [1, 2, 2]]),
            "num_images": [1, 1],
            "pixel_values": torch.arange(8 * 3).reshape(8, 3),
        }
        result = split_pixel_values_by_grid(batch)

        self._assert_tensor_list(result["pixel_values"], [batch["pixel_values"][:4], batch["pixel_values"][4:]])
        self._assert_tensor_list(result["image_grid_thw"], [torch.tensor([[1, 2, 2]]), torch.tensor([[1, 2, 2]])])

    def test_split_correctly_1(self):
        # Grid rows 1*2*2 and 1*2*4 account for 4 and 8 rows of pixel_values respectively.
        batch = {
            "image_grid_thw": torch.tensor([[1, 2, 2], [1, 2, 4]]),
            "num_images": [1, 1],
            "pixel_values": torch.arange(12 * 3).reshape(12, 3),
        }
        result = split_pixel_values_by_grid(batch)

        self._assert_tensor_list(result["pixel_values"], [batch["pixel_values"][:4], batch["pixel_values"][4:12]])
        self._assert_tensor_list(result["image_grid_thw"], [torch.tensor([[1, 2, 2]]), torch.tensor([[1, 2, 4]])])

    def test_missing_keys(self):
        # Without grid/image-count metadata the batch comes back unchanged.
        batch = {"pixel_values": torch.tensor([1.0])}
        assert split_pixel_values_by_grid(batch) == batch

    def test_mismatched_length(self):
        # The grid accounts for 2 + 2 = 4 rows but only 3 are provided.
        batch = {
            "image_grid_thw": torch.tensor([[1, 1, 2], [1, 2, 1]]),
            "num_images": [1, 1],
            "pixel_values": torch.randn(3, 5),
        }
        with pytest.raises(ValueError):
            split_pixel_values_by_grid(batch)

    def test_multi_images(self):
        # Sample 0 holds one image (2 rows); sample 1 holds two images (4 + 2 = 6 rows).
        batch = {
            "image_grid_thw": torch.tensor([[1, 1, 2], [1, 2, 2], [1, 2, 1]]),
            "num_images": [1, 2],
            "pixel_values": torch.arange(8 * 3).reshape(8, 3),
        }
        result = split_pixel_values_by_grid(batch)

        self._assert_tensor_list(result["pixel_values"], [batch["pixel_values"][:2], batch["pixel_values"][2:]])
        self._assert_tensor_list(
            result["image_grid_thw"],
            [torch.tensor([[1, 1, 2]]), torch.tensor([[1, 2, 2], [1, 2, 1]])],
        )

    def test_split_by_image_position_ids(self):
        # No image_grid_thw: the expected split here follows num_images, one pixel_values row per image.
        batch = {
            "num_images": [1, 2],
            "pixel_values": torch.arange(3 * 4).reshape(3, 4),
            "image_position_ids": torch.tensor([[0, 1], [2, 3], [4, 5]]),
        }
        result = split_pixel_values_by_grid(batch)

        self._assert_tensor_list(result["pixel_values"], [batch["pixel_values"][:1], batch["pixel_values"][1:]])
        self._assert_tensor_list(
            result["image_position_ids"],
            [batch["image_position_ids"][:1], batch["image_position_ids"][1:]],
        )
|
|
|
|
|
class TestUnsplitPixelValuesByGrid(TrlTestCase):
    """Tests for ``unsplit_pixel_values_by_grid``, which concatenates per-sample lists back into tensors."""

    def test_unsplit_correctly(self):
        pixel_values = [torch.randn(4, 5), torch.randn(2, 5)]
        image_grid_thw = [torch.tensor([[1, 2, 2]]), torch.tensor([[1, 2, 1]])]
        # Expected tensors are built before the call from the same lists placed in the batch.
        expected_pixels = torch.cat(pixel_values, dim=0)
        expected_grid = torch.cat(image_grid_thw, dim=0)

        result = unsplit_pixel_values_by_grid(
            {"pixel_values": pixel_values, "image_grid_thw": image_grid_thw, "other_key": torch.tensor([1])}
        )

        merged_pixels = result["pixel_values"]
        assert isinstance(merged_pixels, torch.Tensor)
        torch.testing.assert_close(merged_pixels, expected_pixels)

        merged_grid = result["image_grid_thw"]
        assert isinstance(merged_grid, torch.Tensor)
        assert torch.equal(merged_grid, expected_grid)

        # Keys unrelated to images are carried through.
        assert "other_key" in result

    def test_unsplit_image_position_ids(self):
        image_position_ids = [torch.tensor([[0, 1]]), torch.tensor([[2, 3], [4, 5]])]
        expected_ids = torch.cat(image_position_ids, dim=0)
        pixel_values = [torch.randn(1, 4), torch.randn(2, 4)]

        result = unsplit_pixel_values_by_grid({"pixel_values": pixel_values, "image_position_ids": image_position_ids})

        merged_ids = result["image_position_ids"]
        assert isinstance(merged_ids, torch.Tensor)
        assert torch.equal(merged_ids, expected_ids)

    def test_no_op_if_not_list(self):
        # An already-merged tensor is left as is.
        original = torch.randn(5, 3)
        result = unsplit_pixel_values_by_grid({"pixel_values": original})
        assert torch.equal(result["pixel_values"], original)
|
|
|
|
|
class TestChunkedLogProbFunction:
    """Numerical parity of ``_ChunkedLogProbFunction`` against a dense, full-vocabulary reference."""

    N, H, V = 64, 32, 128
    CHUNK_SIZE = 32

    def _reference_logprobs_and_entropy(self, hidden, weight, labels, temperature):
        """Dense reference for per-row label log-probs and entropy.

        Logits are computed as ``hidden @ weight.T`` in the inputs' dtype, upcast to float32 and divided by
        ``temperature`` before the softmax.

        Returns:
            ``(logprobs, entropy)``: the log-probability of each row's label and the softmax entropy of each row.
        """
        logits = (hidden @ weight.t()).to(torch.float32) / temperature
        log_p = F.log_softmax(logits, dim=-1)
        logprobs = log_p.gather(-1, labels.unsqueeze(-1)).squeeze(-1)
        p = torch.softmax(logits, dim=-1)
        entropy = -(p * log_p).sum(dim=-1)
        return logprobs, entropy

    def _check_backward_matches_reference(self, temperature, dtype, atol, rtol):
        """Compare ``hidden``/``weight`` gradients of the chunked op with the dense reference.

        Both paths backpropagate ``logprobs.sum()`` from the same leaf tensors. Gradients are cleared in between so
        the reference gradients do not accumulate on top of the chunked ones.
        """
        torch.manual_seed(42)
        hidden = torch.randn(self.N, self.H, dtype=dtype, requires_grad=True)
        weight = torch.randn(self.V, self.H, dtype=dtype, requires_grad=True)
        labels = torch.randint(0, self.V, (self.N,))

        logprobs_chunked, _ = _ChunkedLogProbFunction.apply(hidden, weight, labels, temperature, self.CHUNK_SIZE)
        logprobs_chunked.sum().backward()
        grad_hidden_chunked = hidden.grad.clone()
        grad_weight_chunked = weight.grad.clone()

        hidden.grad = None
        weight.grad = None

        logprobs_ref, _ = self._reference_logprobs_and_entropy(hidden, weight, labels, temperature)
        logprobs_ref.sum().backward()

        torch.testing.assert_close(grad_hidden_chunked, hidden.grad, atol=atol, rtol=rtol)
        torch.testing.assert_close(grad_weight_chunked, weight.grad, atol=atol, rtol=rtol)

    @pytest.mark.parametrize("temperature", [1.0, 0.7])
    def test_forward(self, temperature):
        torch.manual_seed(42)
        hidden = torch.randn(self.N, self.H)
        weight = torch.randn(self.V, self.H)
        labels = torch.randint(0, self.V, (self.N,))

        logprobs_chunked, entropy_chunked = _ChunkedLogProbFunction.apply(
            hidden, weight, labels, temperature, self.CHUNK_SIZE
        )
        logprobs_ref, entropy_ref = self._reference_logprobs_and_entropy(hidden, weight, labels, temperature)

        torch.testing.assert_close(logprobs_chunked, logprobs_ref, atol=1e-5, rtol=1e-5)
        torch.testing.assert_close(entropy_chunked, entropy_ref, atol=1e-5, rtol=1e-5)

    @pytest.mark.parametrize("temperature", [1.0, 0.7])
    def test_backward(self, temperature):
        self._check_backward_matches_reference(temperature, dtype=torch.float32, atol=1e-5, rtol=1e-5)

    @pytest.mark.parametrize("temperature", [1.0, 0.7])
    def test_backward_bfloat16(self, temperature):
        # bfloat16 inputs carry far less precision than float32, hence the looser tolerance.
        self._check_backward_matches_reference(temperature, dtype=torch.bfloat16, atol=1e-2, rtol=1e-2)
|
|
|
|
|
| class _FakeTransformerModel(nn.Module):
|
| """Minimal stand-in for a transformer body: returns random hidden states of the right shape."""
|
|
|
| def __init__(self, hidden_size):
|
| super().__init__()
|
| self.hidden_size = hidden_size
|
| self._hidden = None
|
|
|
| def forward(self, input_ids, attention_mask=None, use_cache=False, **kwargs):
|
| b, s = input_ids.shape
|
| if self._hidden is None or self._hidden.shape[:2] != (b, s):
|
| torch.manual_seed(123)
|
| self._hidden = torch.randn(b, s, self.hidden_size, requires_grad=True)
|
| return type("Out", (), {"last_hidden_state": self._hidden})()
|
|
|
|
|
class _FakeCausalLM(nn.Module):
    """Tiny causal-LM shell exposing ``.config``, ``.model`` and a bias-free ``.lm_head`` for patch_chunked_lm_head.

    Its own ``forward`` is only a placeholder that raises; it is expected to be replaced by the patch.
    """

    def __init__(self, hidden_size, vocab_size):
        super().__init__()
        empty_config_cls = type("Config", (), {})
        self.config = empty_config_cls()
        self.model = _FakeTransformerModel(hidden_size)
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
        raise NotImplementedError("should be monkey-patched")
|
|
|
|
|
# Shared skip mark for the DeepseekV3 checkpoints, which only work with transformers >= 5.0.0.
_SKIP_DEEPSEEK_V3_BEFORE_V5 = pytest.mark.skipif(
    Version(transformers.__version__) < Version("5.0.0"),
    reason="DeepseekV3 SDPA attention is broken in transformers < 5.0.0",
)

# Tiny hub checkpoints on which the chunked LM head is compared against the model's own logits.
_CHUNKED_LM_HEAD_MODEL_IDS = [
    "trl-internal-testing/tiny-CohereForCausalLM",
    "trl-internal-testing/tiny-Cohere2ForCausalLM",
    pytest.param("trl-internal-testing/tiny-DeepseekV3ForCausalLM", marks=_SKIP_DEEPSEEK_V3_BEFORE_V5),
    pytest.param("trl-internal-testing/tiny-DeepseekV3ForCausalLM-0528", marks=_SKIP_DEEPSEEK_V3_BEFORE_V5),
    "trl-internal-testing/tiny-Gemma2ForCausalLM",
    "trl-internal-testing/tiny-GemmaForCausalLM",
    "trl-internal-testing/tiny-Glm4MoeForCausalLM",
    "trl-internal-testing/tiny-GptOssForCausalLM",
    "trl-internal-testing/tiny-LlamaForCausalLM-3.1",
    "trl-internal-testing/tiny-LlamaForCausalLM-3.2",
    "trl-internal-testing/tiny-LlamaForCausalLM-3",
    "trl-internal-testing/tiny-MistralForCausalLM-0.1",
    "trl-internal-testing/tiny-MistralForCausalLM-0.2",
    "trl-internal-testing/tiny-Phi3ForCausalLM-3",
    "trl-internal-testing/tiny-Phi3ForCausalLM-3.5",
    "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
    "trl-internal-testing/tiny-Qwen3ForCausalLM",
]
|
|
|
|
|
| @require_torch_accelerator
|
| class TestPatchChunkedLMHead:
|
| B, S = 4, 16
|
| H, V = 32, 128
|
| CHUNK_SIZE = 32
|
|
|
def _build_model_and_inputs(self, temperature=1.0):
    """Create a patched fake causal LM plus random inputs whose second half is marked as completion.

    Returns ``(model, input_ids, attention_mask, completion_mask)``. The attention mask is all ones; the float32
    completion mask is ``1.0`` from position ``S // 2`` onwards and ``0.0`` before it.
    """
    torch.manual_seed(42)
    model = _FakeCausalLM(self.H, self.V)
    patch_chunked_lm_head(model, self.CHUNK_SIZE, temperature)

    batch_shape = (self.B, self.S)
    input_ids = torch.randint(0, self.V, batch_shape)
    attention_mask = torch.ones(batch_shape, dtype=torch.long)

    prompt_len = self.S // 2
    completion_mask = torch.zeros(batch_shape, dtype=torch.float32)
    completion_mask[:, prompt_len:] = 1.0
    return model, input_ids, attention_mask, completion_mask
|
|
|
@pytest.mark.parametrize("temperature", [1.0, 0.7])
def test_dummy_model_chunked_forward_with_completion_mask(self, temperature):
    """Masked forward matches unmasked forward at completion positions and is zero at prompt positions."""
    model, input_ids, attention_mask, completion_mask = self._build_model_and_inputs(temperature)
    call_kwargs = {"input_ids": input_ids, "attention_mask": attention_mask, "labels": input_ids}

    # Reference pass without a completion mask.
    out_full = model(**call_kwargs)
    # Drop the cached hidden states so the masked pass regenerates them.
    model.model._hidden = None
    out_masked = model(**call_kwargs, completion_mask=completion_mask)

    # Outputs line up with next-token labels, so the mask is shifted by one position.
    completion_positions = completion_mask[:, 1:].bool()
    prompt_positions = ~completion_positions

    for key in ("log_probs", "entropy"):
        torch.testing.assert_close(
            out_masked[key][completion_positions],
            out_full[key][completion_positions],
            atol=1e-5,
            rtol=1e-5,
        )
    for key in ("log_probs", "entropy"):
        assert (out_masked[key][prompt_positions] == 0).all()
|
|
|
@pytest.mark.parametrize("temperature", [1.0, 0.7])
def test_dummy_model_chunked_forward_completion_mask_backward(self, temperature):
    """lm_head gradients of a completion-weighted loss are the same with and without ``completion_mask``."""
    model, input_ids, attention_mask, completion_mask = self._build_model_and_inputs(temperature)
    loss_weights = completion_mask[:, 1:]

    def lm_head_grad(**extra_kwargs):
        # One forward/backward pass of the completion-weighted log-prob sum; returns a copy of the lm_head grad.
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids, **extra_kwargs)
        (outputs["log_probs"] * loss_weights).sum().backward()
        return model.lm_head.weight.grad.clone()

    grad_weight_full = lm_head_grad()

    # Clear accumulated gradients and the cached hidden states before the second pass.
    model.lm_head.weight.grad = None
    model.model._hidden = None

    grad_weight_masked = lm_head_grad(completion_mask=completion_mask)

    torch.testing.assert_close(grad_weight_masked, grad_weight_full, atol=1e-5, rtol=1e-5)
|
|
|
@pytest.mark.parametrize("model_id", _CHUNKED_LM_HEAD_MODEL_IDS)
@pytest.mark.parametrize("temperature", [1.0, 0.7])
def test_forward(self, model_id, temperature):
    """Patched forward of a real tiny model reproduces log-probs and entropy computed from its full logits."""
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16).to(torch_device)
    if getattr(model.config, "final_logit_softcapping", None) is not None:
        pytest.skip("model uses final_logit_softcapping, not supported by chunked LM head")
    model.eval()

    batch_size, seq_len, chunk_size = 2, 8, 32
    torch.manual_seed(42)
    input_ids = torch.randint(0, model.config.vocab_size, (batch_size, seq_len), device=torch_device)
    labels = input_ids.clone()
    next_token_labels = labels[:, 1:]

    # Reference from the unpatched model: drop the last position, upcast to float32, apply temperature.
    with torch.no_grad():
        ref_logits = model(input_ids=input_ids).logits[:, :-1, :].float() / temperature
        ref_log_p = F.log_softmax(ref_logits, dim=-1)
        ref_logprobs = ref_log_p.gather(-1, next_token_labels.unsqueeze(-1)).squeeze(-1)
        ref_entropy = -(ref_logits.softmax(dim=-1) * ref_log_p).sum(dim=-1)

    patch_chunked_lm_head(model, chunk_size, temperature)
    with torch.no_grad():
        out = model(input_ids=input_ids, labels=labels)

    torch.testing.assert_close(out["log_probs"], ref_logprobs, atol=5e-3, rtol=5e-3)
    torch.testing.assert_close(out["entropy"], ref_entropy, atol=5e-3, rtol=5e-3)
|
|
|
@pytest.mark.parametrize("model_id", _CHUNKED_LM_HEAD_MODEL_IDS)
@pytest.mark.parametrize("temperature", [1.0, 0.7])
def test_backward(self, model_id, temperature):
    """lm_head weight gradients of the patched model match those of the unpatched reference model."""
    model_ref = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16).to(torch_device)
    if getattr(model_ref.config, "final_logit_softcapping", None) is not None:
        pytest.skip("model uses final_logit_softcapping, not supported by chunked LM head")
    # Separate copy, so only this one gets patched while model_ref stays the dense reference.
    model_chunked = copy.deepcopy(model_ref)

    batch_size, seq_len, chunk_size = 2, 8, 32
    torch.manual_seed(42)
    input_ids = torch.randint(0, model_ref.config.vocab_size, (batch_size, seq_len), device=torch_device)
    labels = input_ids.clone()
    next_token_labels = labels[:, 1:]

    # Reference gradients from dense, temperature-scaled float32 log-probs.
    ref_logits = model_ref(input_ids=input_ids).logits[:, :-1, :].float() / temperature
    ref_logprobs = F.log_softmax(ref_logits, dim=-1).gather(-1, next_token_labels.unsqueeze(-1)).squeeze(-1)
    ref_logprobs.sum().backward()
    ref_grad = model_ref.lm_head.weight.grad.clone()

    patch_chunked_lm_head(model_chunked, chunk_size, temperature)
    chunked_out = model_chunked(input_ids=input_ids, labels=labels)
    chunked_out["log_probs"].sum().backward()
    chunked_grad = model_chunked.lm_head.weight.grad.clone()

    torch.testing.assert_close(chunked_grad, ref_grad, atol=5e-2, rtol=5e-2)
|
|
|