|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| import json
|
| import os
|
| from unittest.mock import call, patch
|
|
|
| from datasets import load_dataset
|
| from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig, Trainer, TrainingArguments
|
|
|
| from trl import BEMACallback, LogCompletionsCallback
|
|
|
| from .testing_utils import TrlTestCase, require_comet, require_wandb
|
|
|
|
|
class TestLogCompletionsCallback(TrlTestCase):
    """End-to-end checks that ``LogCompletionsCallback`` logs completion tables to wandb and Comet."""

    def setup_method(self):
        """Load the tiny Qwen2 model and tokenizer and build a tokenized prompt-only dataset."""
        checkpoint = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
        self.model = AutoModelForCausalLM.from_pretrained(checkpoint)
        self.tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        # Pad with the EOS token.
        self.tokenizer.pad_token = self.tokenizer.eos_token

        splits = load_dataset("trl-internal-testing/zen", "standard_prompt_only")
        # Keep only the first 8 training rows so the run stays short.
        splits["train"] = splits["train"].select(range(8))

        def tokenize_function(batch):
            encoded = self.tokenizer(batch["prompt"], padding="max_length", max_length=16, truncation=True)
            # Labels are a copy of the input ids.
            encoded["labels"] = encoded["input_ids"].copy()
            return encoded

        self.dataset = splits.map(tokenize_function, batched=True)
        self.generation_config = GenerationConfig(max_length=32)

    def _train_with_completions_logging(self, report_to):
        """Run a full ``Trainer.train()`` with a ``LogCompletionsCallback`` attached.

        Args:
            report_to: Value forwarded to ``TrainingArguments(report_to=...)``.
        """
        trainer = Trainer(
            model=self.model,
            args=TrainingArguments(
                output_dir=self.tmp_dir,
                eval_strategy="steps",
                eval_steps=2,  # evaluate every 2 steps
                per_device_train_batch_size=2,  # 8 training rows -> batches of 2
                per_device_eval_batch_size=2,
                report_to=report_to,
            ),
            train_dataset=self.dataset["train"],
            eval_dataset=self.dataset["test"],
            processing_class=self.tokenizer,
        )
        # The callback takes the trainer itself, so it can only be attached after construction.
        trainer.add_callback(LogCompletionsCallback(trainer, self.generation_config, num_prompts=2))
        trainer.train()

    @require_wandb
    def test_basic_wandb(self):
        """The completions table stored in the wandb run summary has the expected columns and prompt."""
        import wandb

        self._train_with_completions_logging("wandb")

        # The run summary stores the table's path relative to the run directory.
        active_run = wandb.run
        relative_path = active_run.summary.completions["path"]
        table_file = os.path.join(active_run.dir, relative_path)
        with open(table_file) as handle:
            logged = json.load(handle)

        # The table must expose all three expected columns.
        for column in ("step", "prompt", "completion"):
            assert column in logged["columns"]

        # The first logged row must contain the first evaluation prompt.
        first_prompt = self.dataset["test"][0]["prompt"]
        assert first_prompt in logged["data"][0]

    @require_comet
    def test_basic_comet(self):
        """Every dataframe asset logged to the Comet experiment is a ``completions.csv`` table."""
        import comet_ml

        self._train_with_completions_logging("comet_ml")

        # End the running experiment before querying its assets.
        # NOTE(review): presumably this flushes pending uploads - confirm against the Comet docs.
        experiment = comet_ml.get_running_experiment()
        assert experiment is not None
        experiment.end()

        # Expected table count is derived from the combined train and test sizes.
        # NOTE(review): the "+ 1" presumably accounts for a table logged at step 0 - confirm.
        total_rows = len(self.dataset["train"]) + len(self.dataset["test"])
        expected_tables = total_rows // 2 + 1

        api_experiment = comet_ml.APIExperiment(previous_experiment=experiment.id)
        assets = api_experiment.get_asset_list("dataframe")
        assert assets is not None
        assert len(assets) == expected_tables
        for asset in assets:
            assert asset["fileName"] == "completions.csv"
|
|
|
|
|
class TestBEMACallback(TrlTestCase):
    """Tests for ``BEMACallback``: model saving, update scheduling, and disabled EMA/BEMA terms."""

    def setup_method(self):
        """Load the tiny Qwen2 model and tokenizer and tokenize the language-modeling dataset."""
        checkpoint = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
        self.model = AutoModelForCausalLM.from_pretrained(checkpoint)
        self.tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        # Pad with the EOS token.
        self.tokenizer.pad_token = self.tokenizer.eos_token

        raw_splits = load_dataset("trl-internal-testing/zen", "standard_language_modeling")

        def tokenize_function(batch, tokenizer):
            encoded = tokenizer(batch["text"], padding="max_length", max_length=17)
            # Labels are a copy of the input ids.
            encoded["labels"] = encoded["input_ids"].copy()
            return encoded

        self.dataset = raw_splits.map(
            tokenize_function,
            batched=True,
            remove_columns=["text"],
            fn_kwargs={"tokenizer": self.tokenizer},
        )

    def _train_with(self, bema_callback):
        """Train on the ``train`` split with ``bema_callback`` registered as the only extra callback."""
        trainer = Trainer(
            model=self.model,
            args=TrainingArguments(output_dir=self.tmp_dir, report_to="none"),
            train_dataset=self.dataset["train"],
            processing_class=self.tokenizer,
            callbacks=[bema_callback],
        )
        trainer.train()

    def _recorded_update_calls(self, bema_callback):
        """Train with ``_update_bema_weights`` mocked out and return the calls it received."""
        with patch.object(bema_callback, "_update_bema_weights") as mock_update:
            self._train_with(bema_callback)
        return mock_update.call_args_list

    def test_model_saved(self):
        """After training, a loadable checkpoint exists in the ``bema`` subdirectory of the output dir."""
        self._train_with(BEMACallback(update_freq=2))

        saved_dir = os.path.join(self.tmp_dir, "bema")
        assert os.path.isdir(saved_dir), "BEMA directory was not created"
        # The load result is discarded: this only checks that loading does not raise.
        AutoModelForCausalLM.from_pretrained(saved_dir)

    def test_update_frequency_0(self):
        """With ``update_freq=2`` the weight update is called at steps 2, 4, 6 and 8."""
        expected_steps = (2, 4, 6, 8)
        recorded = self._recorded_update_calls(BEMACallback(update_freq=2))
        assert recorded == [call(step) for step in expected_steps]

    def test_update_frequency_1(self):
        """With ``update_freq=3`` the weight update is called at steps 3, 6 and 9."""
        expected_steps = (3, 6, 9)
        recorded = self._recorded_update_calls(BEMACallback(update_freq=3))
        assert recorded == [call(step) for step in expected_steps]

    def test_update_frequency_2(self):
        """With ``update_freq=2`` and ``update_after=3`` the update is called at steps 5, 7 and 9."""
        expected_steps = (5, 7, 9)
        recorded = self._recorded_update_calls(BEMACallback(update_freq=2, update_after=3))
        assert recorded == [call(step) for step in expected_steps]

    def test_no_bema(self):
        """Training completes when ``bias_power=0.0`` (no BEMA updates)."""
        self._train_with(BEMACallback(update_freq=2, bias_power=0.0))

    def test_no_ema(self):
        """Training completes when ``ema_power=0.0`` (no EMA updates)."""
        self._train_with(BEMACallback(update_freq=2, ema_power=0.0))
|
|
|