# File: trl-mcsd/tests/experimental/test_merge_model_callback.py
# Last commit 1fa3c6c (verified), ihbkaiser: "Implement MCSD for experimental SDPO"
# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.trainer_utils import get_last_checkpoint
from trl import DPOConfig, DPOTrainer
from trl.experimental.merge_model_callback import MergeConfig, MergeModelCallback
from ..testing_utils import TrlTestCase, require_mergekit
@require_mergekit
class TestMergeModelCallback(TrlTestCase):
    """Integration tests for `MergeModelCallback` attached to a short `DPOTrainer` run.

    Every test trains for one epoch with `save_strategy="steps"` and `save_steps=1`, then checks that a
    `merged` directory exists inside the expected checkpoint folder(s) under `self.tmp_dir`.
    """

    def setup_method(self):
        """Load the tiny test model, its tokenizer and the preference dataset shared by all tests."""
        model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
        self.model = AutoModelForCausalLM.from_pretrained(model_id, dtype="float32")
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.dataset = load_dataset("trl-internal-testing/zen", "standard_preference", split="train")

    def _train_with_callback(self, merge_callback):
        """Run one epoch of DPO training into `self.tmp_dir` with `merge_callback` registered.

        Args:
            merge_callback: The `MergeModelCallback` instance to pass to the trainer.
        """
        # NOTE(review): `self.tmp_dir` is not set in this class; presumably provided by `TrlTestCase` - confirm.
        training_args = DPOConfig(
            output_dir=self.tmp_dir,
            num_train_epochs=1,
            report_to="none",
            save_strategy="steps",
            save_steps=1,
        )
        trainer = DPOTrainer(
            model=self.model,
            args=training_args,
            train_dataset=self.dataset,
            processing_class=self.tokenizer,
            callbacks=[merge_callback],
        )
        trainer.train()

    def test_callback(self):
        """With default callback options, the last checkpoint must contain a `merged` directory."""
        self._train_with_callback(MergeModelCallback(MergeConfig()))

        last_checkpoint = get_last_checkpoint(self.tmp_dir)
        # `get_last_checkpoint` returns None when no checkpoint folder is found; fail with a clear message
        # instead of a TypeError from `os.path.join`.
        assert last_checkpoint is not None, "No checkpoint was saved during training."
        merged_path = os.path.join(last_checkpoint, "merged")
        assert os.path.isdir(merged_path), "Merged folder does not exist in the last checkpoint."

    def test_every_checkpoint(self):
        """With `merge_at_every_checkpoint=True`, every saved checkpoint must contain a `merged` directory."""
        self._train_with_callback(MergeModelCallback(MergeConfig(), merge_at_every_checkpoint=True))

        checkpoints = sorted(
            os.path.join(self.tmp_dir, entry)
            for entry in os.listdir(self.tmp_dir)
            if entry.startswith("checkpoint-")
        )
        # Without this guard the loop below would pass vacuously if training saved nothing.
        assert checkpoints, "No checkpoint was saved during training."
        for checkpoint in checkpoints:
            merged_path = os.path.join(checkpoint, "merged")
            assert os.path.isdir(merged_path), f"Merged folder does not exist in checkpoint {checkpoint}."