| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import shutil |
| | import tempfile |
| | import unittest |
| | from unittest.mock import patch |
| |
|
| | from transformers import ( |
| | DefaultFlowCallback, |
| | IntervalStrategy, |
| | PrinterCallback, |
| | ProgressCallback, |
| | Trainer, |
| | TrainerCallback, |
| | TrainingArguments, |
| | is_torch_available, |
| | ) |
| | from transformers.testing_utils import require_torch |
| |
|
| |
|
| | if is_torch_available(): |
| | from transformers.trainer import DEFAULT_CALLBACKS |
| |
|
| | from .test_trainer import RegressionDataset, RegressionModelConfig, RegressionPreTrainedModel |
| |
|
| |
|
class MyTestTrainerCallback(TrainerCallback):
    """Test double that keeps an ordered log of every callback hook invoked on it.

    Each hook ignores its arguments and appends its own name to ``self.events``,
    so a test can compare the recorded sequence against an expected one.
    """

    def __init__(self):
        # Hook names, in the order they were called.
        self.events = []

    def _record(self, event_name):
        """Append ``event_name`` to the event log."""
        self.events.append(event_name)

    def on_init_end(self, args, state, control, **kwargs):
        self._record("on_init_end")

    def on_train_begin(self, args, state, control, **kwargs):
        self._record("on_train_begin")

    def on_train_end(self, args, state, control, **kwargs):
        self._record("on_train_end")

    def on_epoch_begin(self, args, state, control, **kwargs):
        self._record("on_epoch_begin")

    def on_epoch_end(self, args, state, control, **kwargs):
        self._record("on_epoch_end")

    def on_step_begin(self, args, state, control, **kwargs):
        self._record("on_step_begin")

    def on_step_end(self, args, state, control, **kwargs):
        self._record("on_step_end")

    def on_evaluate(self, args, state, control, **kwargs):
        self._record("on_evaluate")

    def on_predict(self, args, state, control, **kwargs):
        self._record("on_predict")

    def on_save(self, args, state, control, **kwargs):
        self._record("on_save")

    def on_log(self, args, state, control, **kwargs):
        self._record("on_log")

    def on_prediction_step(self, args, state, control, **kwargs):
        self._record("on_prediction_step")
| |
|
| |
|
@require_torch
class TrainerCallbackTest(unittest.TestCase):
    """Checks callback registration on ``Trainer`` and the order in which callback events fire."""

    def setUp(self):
        # Each test gets its own scratch directory for trainer output.
        self.output_dir = tempfile.mkdtemp()

    def tearDown(self):
        shutil.rmtree(self.output_dir)

    def get_trainer(self, a=0, b=0, train_len=64, eval_len=64, callbacks=None, disable_tqdm=False, **kwargs):
        """Build a ``Trainer`` around a small regression model.

        Args:
            a, b: Values forwarded to ``RegressionModelConfig``.
            train_len, eval_len: Lengths of the train / eval ``RegressionDataset``.
            callbacks: Extra callbacks handed to the ``Trainer``.
            disable_tqdm: Forwarded to ``TrainingArguments``.
            **kwargs: Any further ``TrainingArguments`` keyword arguments.

        Returns:
            The configured ``Trainer``; its output goes to ``self.output_dir``.
        """
        train_dataset = RegressionDataset(length=train_len)
        eval_dataset = RegressionDataset(length=eval_len)
        config = RegressionModelConfig(a=a, b=b)
        model = RegressionPreTrainedModel(config)

        # ``disable_tqdm`` is always passed explicitly: ``test_init_callback`` expects
        # ProgressCallback when it is False and PrinterCallback when it is True.
        args = TrainingArguments(self.output_dir, disable_tqdm=disable_tqdm, report_to=[], **kwargs)
        return Trainer(
            model,
            args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            callbacks=callbacks,
        )

    @staticmethod
    def _callback_name(cb):
        """Return the class name of ``cb``, whether it is a callback class or an instance."""
        return cb.__name__ if isinstance(cb, type) else cb.__class__.__name__

    def check_callbacks_equality(self, cbs1, cbs2):
        """Assert that two callback collections match, ignoring order.

        Entries may be callback classes or instances. When at least one side of a
        pair is a class, the pair is compared by class; two instances are compared
        directly.
        """
        self.assertEqual(len(cbs1), len(cbs2))

        # Sort both sides by class name so the pairing below does not depend on order.
        cbs1 = sorted(cbs1, key=self._callback_name)
        cbs2 = sorted(cbs2, key=self._callback_name)

        for cb1, cb2 in zip(cbs1, cbs2):
            if isinstance(cb1, type) or isinstance(cb2, type):
                cls1 = cb1 if isinstance(cb1, type) else cb1.__class__
                cls2 = cb2 if isinstance(cb2, type) else cb2.__class__
                self.assertEqual(cls1, cls2)
            else:
                self.assertEqual(cb1, cb2)

    def get_expected_events(self, trainer):
        """Return the event-name sequence a full ``trainer.train()`` run is expected to produce.

        The sequence is derived from ``trainer.args`` (logging / eval / save
        intervals and evaluation strategy), ``trainer.state.num_train_epochs`` and
        the lengths of the train and eval dataloaders.
        """
        expected_events = ["on_init_end", "on_train_begin"]
        step = 0
        # Steps per epoch come from the *train* dataloader. The eval dataloader only
        # determines how many prediction steps a single evaluation emits.
        train_dl_len = len(trainer.get_train_dataloader())
        eval_dl_len = len(trainer.get_eval_dataloader())
        evaluation_events = ["on_prediction_step"] * eval_dl_len + ["on_log", "on_evaluate"]
        for _ in range(trainer.state.num_train_epochs):
            expected_events.append("on_epoch_begin")
            for _ in range(train_dl_len):
                step += 1
                expected_events += ["on_step_begin", "on_step_end"]
                if step % trainer.args.logging_steps == 0:
                    expected_events.append("on_log")
                if trainer.args.evaluation_strategy == IntervalStrategy.STEPS and step % trainer.args.eval_steps == 0:
                    expected_events += evaluation_events
                if step % trainer.args.save_steps == 0:
                    expected_events.append("on_save")
            expected_events.append("on_epoch_end")
            if trainer.args.evaluation_strategy == IntervalStrategy.EPOCH:
                expected_events += evaluation_events
        expected_events += ["on_log", "on_train_end"]
        return expected_events

    def test_init_callback(self):
        trainer = self.get_trainer()
        expected_callbacks = DEFAULT_CALLBACKS.copy() + [ProgressCallback]
        self.check_callbacks_equality(trainer.callback_handler.callbacks, expected_callbacks)

        # A callback passed at construction is registered on top of the defaults.
        trainer = self.get_trainer(callbacks=[MyTestTrainerCallback])
        expected_callbacks.append(MyTestTrainerCallback)
        self.check_callbacks_equality(trainer.callback_handler.callbacks, expected_callbacks)

        # With tqdm disabled, PrinterCallback is expected in place of ProgressCallback.
        trainer = self.get_trainer(disable_tqdm=True)
        expected_callbacks = DEFAULT_CALLBACKS.copy() + [PrinterCallback]
        self.check_callbacks_equality(trainer.callback_handler.callbacks, expected_callbacks)

    def test_add_remove_callback(self):
        expected_callbacks = DEFAULT_CALLBACKS.copy() + [ProgressCallback]
        trainer = self.get_trainer()

        # Remove / pop / add, addressing the callback by its class.
        trainer.remove_callback(DefaultFlowCallback)
        expected_callbacks.remove(DefaultFlowCallback)
        self.check_callbacks_equality(trainer.callback_handler.callbacks, expected_callbacks)

        trainer = self.get_trainer()
        cb = trainer.pop_callback(DefaultFlowCallback)
        self.assertEqual(cb.__class__, DefaultFlowCallback)
        self.check_callbacks_equality(trainer.callback_handler.callbacks, expected_callbacks)

        trainer.add_callback(DefaultFlowCallback)
        expected_callbacks.insert(0, DefaultFlowCallback)
        self.check_callbacks_equality(trainer.callback_handler.callbacks, expected_callbacks)

        # Same operations, addressing the callback by instance.
        trainer = self.get_trainer()
        cb = trainer.callback_handler.callbacks[0]
        trainer.remove_callback(cb)
        expected_callbacks.remove(DefaultFlowCallback)
        self.check_callbacks_equality(trainer.callback_handler.callbacks, expected_callbacks)

        trainer = self.get_trainer()
        cb1 = trainer.callback_handler.callbacks[0]
        cb2 = trainer.pop_callback(cb1)
        self.assertEqual(cb1, cb2)
        self.check_callbacks_equality(trainer.callback_handler.callbacks, expected_callbacks)

        trainer.add_callback(cb1)
        expected_callbacks.insert(0, DefaultFlowCallback)
        self.check_callbacks_equality(trainer.callback_handler.callbacks, expected_callbacks)

    def _assert_event_flow(self, **trainer_kwargs):
        """Train with a recording callback and compare its log with the expected sequence."""
        trainer = self.get_trainer(callbacks=[MyTestTrainerCallback], **trainer_kwargs)
        trainer.train()
        # NOTE(review): presumably the progress callback is appended after user
        # callbacks, leaving the recording callback at index -2 - confirm against Trainer.
        events = trainer.callback_handler.callbacks[-2].events
        self.assertEqual(events, self.get_expected_events(trainer))

    def test_event_flow(self):
        import warnings

        # Scope the filter to this test so that ignoring UserWarning does not leak
        # into the rest of the test session.
        with warnings.catch_warnings():
            warnings.simplefilter(action="ignore", category=UserWarning)

            self._assert_event_flow()

            # Independent log / save / eval intervals.
            self._assert_event_flow(logging_steps=5)
            self._assert_event_flow(save_steps=5)
            self._assert_event_flow(eval_steps=5, evaluation_strategy="steps")
            self._assert_event_flow(evaluation_strategy="epoch")

            # Several intervals combined in one run.
            self._assert_event_flow(
                logging_steps=3,
                save_steps=10,
                eval_steps=5,
                evaluation_strategy="steps",
            )

            # Registering the same callback class twice must log a warning naming it.
            with patch("transformers.trainer_callback.logger.warning") as warn_mock:
                trainer = self.get_trainer(
                    callbacks=[MyTestTrainerCallback, MyTestTrainerCallback],
                )
                # Fail with a clear message, not a TypeError on ``None``, when nothing was logged.
                warn_mock.assert_called()
                self.assertIn(str(MyTestTrainerCallback), warn_mock.call_args[0][0])
| |
|