import sys
import unittest
from unittest.mock import patch

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from .testing_utils import is_peft_available, require_peft


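# Minimal dataset that yields (query, response) tensor pairs for the PPO trainer test below.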
class DummyDataset(torch.utils.data.Dataset):
    def __init__(self, query_data, response_data):
        self.query_data = query_data
        self.response_data = response_data

    def __len__(self):
        return len(self.query_data)

    def __getitem__(self, idx):
        return self.query_data[idx], self.response_data[idx]


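# Keys that `PPOTrainer.step` is expected to report in its stats dictionary.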
EXPECTED_STATS = [
    "objective/kl",
    "objective/kl_dist",
    "objective/logprobs",
    "objective/ref_logprobs",
    "objective/kl_coef",
    "objective/entropy",
    "ppo/mean_non_score_reward",
    "ppo/loss/policy",
    "ppo/loss/value",
    "ppo/loss/total",
    "ppo/policy/entropy",
    "ppo/policy/approxkl",
    "ppo/policy/policykl",
    "ppo/policy/clipfrac",
    "ppo/policy/advantages",
    "ppo/policy/advantages_mean",
    "ppo/policy/ratio",
    "ppo/returns/mean",
    "ppo/returns/var",
    "ppo/val/vpred",
    "ppo/val/error",
    "ppo/val/clipfrac",
    "ppo/val/mean",
    "ppo/val/var",
    "ppo/val/var_explained",
    "time/ppo/forward_pass",
    "time/ppo/compute_rewards",
    "time/ppo/optimize_step",
    "time/ppo/calc_stats",
    "time/ppo/total",
    "ppo/learning_rate",
]


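# The tests below simulate an environment without `peft` by patching it out of
# `sys.modules`, and check that trl models and the PPOTrainer keep working.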
@require_peft
class TestPeftDependancy(unittest.TestCase):
    def setUp(self):
        self.causal_lm_model_id = "trl-internal-testing/tiny-random-GPTNeoXForCausalLM"
        self.seq_to_seq_model_id = "trl-internal-testing/tiny-random-T5ForConditionalGeneration"

        if is_peft_available():
            from peft import LoraConfig, get_peft_model

            lora_config = LoraConfig(
                r=16,
                lora_alpha=32,
                lora_dropout=0.05,
                bias="none",
                task_type="CAUSAL_LM",
            )

            causal_lm_model = AutoModelForCausalLM.from_pretrained(self.causal_lm_model_id)
            self.peft_model = get_peft_model(causal_lm_model, lora_config)

    def test_no_peft(self):
        with patch.dict(sys.modules, {"peft": None}):
            from trl import AutoModelForCausalLMWithValueHead, AutoModelForSeq2SeqLMWithValueHead

            # `peft` must not be importable inside the patched context
            with self.assertRaises(ModuleNotFoundError):
                import peft  # noqa

            # loading the value-head models should still work without peft installed
            trl_model = AutoModelForCausalLMWithValueHead.from_pretrained(self.causal_lm_model_id)  # noqa
            trl_seq2seq_model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(self.seq_to_seq_model_id)  # noqa

    def test_imports_no_peft(self):
        with patch.dict(sys.modules, {"peft": None}):
            # importing the public trl API must succeed even when peft is unavailable
            from trl import (  # noqa
                AutoModelForCausalLMWithValueHead,
                AutoModelForSeq2SeqLMWithValueHead,
                PPOConfig,
                PPOTrainer,
                PreTrainedModelWrapper,
            )

    def test_ppo_trainer_no_peft(self):
        with patch.dict(sys.modules, {"peft": None}):
            from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer

            ppo_model_id = "trl-internal-testing/dummy-GPT2-correct-vocab"

            trl_model = AutoModelForCausalLMWithValueHead.from_pretrained(ppo_model_id)
            tokenizer = AutoTokenizer.from_pretrained(ppo_model_id)
            tokenizer.pad_token_id = tokenizer.eos_token_id

            ppo_config = PPOConfig(batch_size=2, mini_batch_size=1, log_with=None)

            dummy_dataset = DummyDataset(
                [torch.LongTensor([0, 1, 0, 1, 0, 1]), torch.LongTensor([0, 1, 0, 1, 0, 1])],
                [torch.LongTensor([1, 0, 1, 0, 1, 0]), torch.LongTensor([0, 1, 0, 1, 0, 1])],
            )

            ppo_trainer = PPOTrainer(
                config=ppo_config,
                model=trl_model,
                ref_model=None,
                tokenizer=tokenizer,
                dataset=dummy_dataset,
            )
            dummy_dataloader = ppo_trainer.dataloader

            # run a single PPO optimisation step
            for query_tensor, response_tensor in dummy_dataloader:
                # dummy rewards for the two responses (could be any scalar reward signal)
                reward = [torch.tensor(1.0), torch.tensor(0.0)]
                # train model
                train_stats = ppo_trainer.step([q for q in query_tensor], [r for r in response_tensor], reward)
                break

            # check that gradients were populated for all trainable parameters
            for _, param in trl_model.named_parameters():
                if param.requires_grad:
                    self.assertIsNotNone(param.grad)

            # check that all expected stats were logged
            for stat in EXPECTED_STATS:
                self.assertIn(stat, train_stats)