|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| import pytest
|
| import torch
|
| from datasets import load_dataset
|
|
|
| from trl.experimental.dppo import DPPOConfig, DPPOTrainer
|
|
|
| from ..testing_utils import TrlTestCase
|
|
|
|
|
class TestDPPODivergenceMask:
    """Unit tests for ``DPPOTrainer._compute_divergence_mask`` with synthetic inputs.

    No model is loaded: the method is called unbound on a bare stub object that
    only carries the attributes set by :meth:`make_trainer`. Log-probabilities
    are built by taking ``torch.log`` of hand-picked probabilities so each
    scenario is easy to read off the literals.
    """

    @staticmethod
    def make_trainer(divergence_type="binary_tv", epsilon=0.2, epsilon_high=0.28):
        """Create a minimal DPPOTrainer-like object with just the attributes needed for _compute_divergence_mask.

        Args:
            divergence_type: Stored verbatim as ``stub.divergence_type``.
            epsilon: Stored as ``stub.epsilon_low`` (note the attribute name differs
                from the parameter name).
            epsilon_high: Stored as ``stub.epsilon_high``.

        Returns:
            A plain object exposing ``divergence_type``, ``epsilon_low`` and
            ``epsilon_high``.
        """

        class Stub:
            pass

        stub = Stub()
        stub.divergence_type = divergence_type
        stub.epsilon_low = epsilon
        stub.epsilon_high = epsilon_high
        return stub

    @staticmethod
    def compute_divergence_mask(
        trainer_stub,
        current_logps,
        sampling_logps,
        advantages,
        completion_mask,
        current_topk_logps=None,
        sampling_topk_logps=None,
    ):
        """Call ``DPPOTrainer._compute_divergence_mask`` unbound, with ``trainer_stub`` standing in for ``self``.

        All arguments are forwarded unchanged; the two top-k tensors are passed
        by keyword and default to ``None``.
        """
        return DPPOTrainer._compute_divergence_mask(
            trainer_stub,
            current_logps,
            sampling_logps,
            advantages,
            completion_mask,
            current_topk_logps=current_topk_logps,
            sampling_topk_logps=sampling_topk_logps,
        )

    def test_binary_tv_no_masking_within_threshold(self):
        """Probabilities that move by only 0.01 under thresholds 0.2 / 0.28 must leave every token unmasked."""
        stub = self.make_trainer("binary_tv", epsilon=0.2, epsilon_high=0.28)

        sampling_logps = torch.log(torch.tensor([[0.5, 0.3, 0.7]]))
        current_logps = torch.log(torch.tensor([[0.51, 0.29, 0.71]]))
        advantages = torch.tensor([[1.0]])
        completion_mask = torch.ones(1, 3)

        mask = self.compute_divergence_mask(stub, current_logps, sampling_logps, advantages, completion_mask)
        assert mask.shape == (1, 3)
        assert (mask == 1.0).all()

    def test_binary_tv_masks_positive_advantage_high_divergence(self):
        """A probability jump from 0.1 to 0.5 with a positive advantage and thresholds of 0.01 must be masked out."""
        stub = self.make_trainer("binary_tv", epsilon=0.01, epsilon_high=0.01)

        sampling_logps = torch.log(torch.tensor([[0.1]]))
        current_logps = torch.log(torch.tensor([[0.5]]))
        advantages = torch.tensor([[1.0]])
        completion_mask = torch.ones(1, 1)

        mask = self.compute_divergence_mask(stub, current_logps, sampling_logps, advantages, completion_mask)
        assert mask.item() == 0.0

    def test_binary_tv_masks_negative_advantage_low_divergence(self):
        """A probability drop from 0.5 to 0.1 with a negative advantage and thresholds of 0.01 must be masked out."""
        stub = self.make_trainer("binary_tv", epsilon=0.01, epsilon_high=0.01)

        sampling_logps = torch.log(torch.tensor([[0.5]]))
        current_logps = torch.log(torch.tensor([[0.1]]))
        advantages = torch.tensor([[-1.0]])
        completion_mask = torch.ones(1, 1)

        mask = self.compute_divergence_mask(stub, current_logps, sampling_logps, advantages, completion_mask)
        assert mask.item() == 0.0

    def test_binary_tv_respects_completion_mask(self):
        """The token whose ``completion_mask`` entry is 0 must come back as 0 in the returned mask."""
        stub = self.make_trainer("binary_tv", epsilon=0.01, epsilon_high=0.01)

        sampling_logps = torch.log(torch.tensor([[0.1, 0.5]]))
        current_logps = torch.log(torch.tensor([[0.9, 0.9]]))
        advantages = torch.tensor([[1.0]])
        completion_mask = torch.tensor([[1.0, 0.0]])

        mask = self.compute_divergence_mask(stub, current_logps, sampling_logps, advantages, completion_mask)
        # NOTE(review): token 1 also moves 0.5 -> 0.9 with a positive advantage, the same
        # pattern the high-divergence test expects to be masked, so this assertion may hold
        # even if completion_mask were ignored - confirm whether a within-threshold padded
        # token should be used here instead.
        assert mask[0, 1].item() == 0.0

    def test_topk_tv_requires_topk_inputs(self):
        """With ``topk_tv`` and both top-k tensors supplied, a tiny perturbation must leave every token unmasked.

        NOTE(review): despite the name, only the path where the top-k inputs ARE
        provided is exercised; the behaviour when they are missing is not checked here.
        """
        stub = self.make_trainer("topk_tv")
        B, T, K = 1, 2, 4
        sampling_logps = torch.log(torch.full((B, T), 0.3))
        current_logps = torch.log(torch.full((B, T), 0.31))
        advantages = torch.tensor([[1.0]])
        completion_mask = torch.ones(B, T)

        # Draw from a locally seeded generator so the inputs are identical on every run
        # (reproducible failures) without touching the global RNG state.
        rng = torch.Generator().manual_seed(0)
        topk_probs = torch.softmax(torch.randn(B, T, K, generator=rng), dim=-1)
        sampling_topk_logps = torch.log(topk_probs)
        # The current policy differs from the sampling one by +0.001 on each top-k probability.
        current_topk_logps = torch.log(topk_probs + 0.001)

        mask = self.compute_divergence_mask(
            stub,
            current_logps,
            sampling_logps,
            advantages,
            completion_mask,
            current_topk_logps=current_topk_logps,
            sampling_topk_logps=sampling_topk_logps,
        )
        assert mask.shape == (B, T)
        assert (mask == 1.0).all()
|
|
|
|
|
@pytest.mark.low_priority
class TestDPPOTrainer(TrlTestCase):
    """End-to-end smoke tests that run ``DPPOTrainer.train()`` on tiny test checkpoints."""

    def _build_trainer(self, dataset, **config_overrides):
        """Build a ``DPPOTrainer`` for ``dataset`` using the settings shared by every test in this class.

        Args:
            dataset: Training dataset handed to the trainer as ``train_dataset``.
            **config_overrides: Extra keyword arguments forwarded to ``DPPOConfig``
                on top of the shared ones (e.g. ``divergence_type``).

        Returns:
            The constructed, not-yet-trained ``DPPOTrainer``.
        """
        config = DPPOConfig(
            output_dir=self.tmp_dir,
            learning_rate=0.1,
            per_device_train_batch_size=3,
            num_generations=3,
            max_completion_length=8,
            report_to="none",
            **config_overrides,
        )
        return DPPOTrainer(
            model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
            reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
            args=config,
            train_dataset=dataset,
        )

    @pytest.mark.parametrize("divergence_type", ["binary_tv", "binary_kl"])
    def test_training_binary(self, divergence_type):
        """Train with each binary divergence type; expect a logged train loss and every parameter to change."""
        train_split = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
        trainer = self._build_trainer(train_split, divergence_type=divergence_type)

        # Snapshot every named parameter before training so updates can be detected afterwards.
        params_before = {}
        for param_name, tensor in trainer.model.named_parameters():
            params_before[param_name] = tensor.clone()

        trainer.train()

        final_log = trainer.state.log_history[-1]
        assert final_log["train_loss"] is not None

        for n, old_value in params_before.items():
            assert not torch.equal(old_value, trainer.model.get_parameter(n)), f"Parameter {n} has not changed."

    @pytest.mark.parametrize("config_name", ["standard_prompt_only", "conversational_prompt_only"])
    def test_training_conversational(self, config_name):
        """Train on both the standard and the conversational prompt-only configs; expect a logged train loss."""
        train_split = load_dataset("trl-internal-testing/zen", config_name, split="train")
        trainer = self._build_trainer(train_split)

        trainer.train()

        final_log = trainer.state.log_history[-1]
        assert final_log["train_loss"] is not None
|
|
|