| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import os |
| import signal |
| import tempfile |
| import unittest |
| from unittest.mock import Mock, patch |
|
|
| from transformers import TrainingArguments, is_torch_available |
| from transformers.testing_utils import require_torch |
|
|
|
|
| if is_torch_available(): |
| from transformers.trainer_jit_checkpoint import CheckpointManager, JITCheckpointCallback |
|
|
| from .test_trainer import RegressionDataset, RegressionModelConfig, RegressionPreTrainedModel |
|
|
|
|
@require_torch
class JITCheckpointTest(unittest.TestCase):
    """Tests for JIT checkpointing: ``CheckpointManager`` and ``JITCheckpointCallback``."""

    def setUp(self):
        self.test_dir = tempfile.mkdtemp()
        # Building a Trainer with ``enable_jit_checkpoint=True`` or calling
        # ``JITCheckpointCallback.set_trainer`` installs a process-wide SIGTERM
        # handler. Remember the handler that was active before the test so that
        # tearDown can put it back; otherwise the manager's handler leaks into
        # every later test in the session and SIGTERM no longer terminates the
        # test process.
        self._saved_sigterm_handler = signal.getsignal(signal.SIGTERM)

    def tearDown(self):
        import shutil

        # signal.getsignal() returns None when the previous handler was not
        # installed from Python; signal.signal() cannot reinstall that value.
        if self._saved_sigterm_handler is not None:
            signal.signal(signal.SIGTERM, self._saved_sigterm_handler)
        shutil.rmtree(self.test_dir, ignore_errors=True)

    def get_trainer(self, enable_jit=True):
        """Build a small regression ``Trainer`` that writes to ``self.test_dir``.

        Args:
            enable_jit: Forwarded to ``TrainingArguments(enable_jit_checkpoint=...)``.

        Returns:
            A ``Trainer`` over a 64-sample ``RegressionDataset``, capped at 10 steps.
        """
        from transformers import Trainer

        model_config = RegressionModelConfig(a=1.5, b=2.5)
        model = RegressionPreTrainedModel(model_config)

        args = TrainingArguments(
            output_dir=self.test_dir,
            enable_jit_checkpoint=enable_jit,
            per_device_train_batch_size=16,
            learning_rate=0.1,
            logging_steps=1,
            num_train_epochs=1,
            max_steps=10,
            save_steps=10,
        )

        train_dataset = RegressionDataset(length=64)

        return Trainer(model=model, args=args, train_dataset=train_dataset)

    def test_checkpoint_manager_initialization(self):
        """Test CheckpointManager initialization with default and custom ``kill_wait``."""
        trainer = self.get_trainer()

        # Defaults: 3 second kill wait and no pending checkpoint request.
        manager = CheckpointManager(trainer)
        self.assertIs(manager.trainer, trainer)
        self.assertEqual(manager.kill_wait, 3)
        self.assertFalse(manager.is_checkpoint_requested)

        # A custom kill wait is stored as given.
        manager_custom = CheckpointManager(trainer, kill_wait=5)
        self.assertEqual(manager_custom.kill_wait, 5)

    def test_signal_handler_setup(self):
        """Test that setup_signal_handler installs the manager's handler and remembers the previous one."""
        trainer = self.get_trainer()
        manager = CheckpointManager(trainer)

        # Start from the default disposition so the remembered handler is known exactly.
        original_handler = signal.signal(signal.SIGTERM, signal.SIG_DFL)
        try:
            manager.setup_signal_handler()

            # Read the active handler with getsignal() instead of swapping in another one.
            self.assertEqual(signal.getsignal(signal.SIGTERM), manager._sigterm_handler)

            # The handler active before setup is kept so it can be restored later.
            self.assertEqual(manager._original_sigterm_handler, signal.SIG_DFL)
        finally:
            signal.signal(signal.SIGTERM, original_handler)

    @patch("threading.Timer")
    def test_sigterm_handler_flow(self, mock_timer):
        """Test SIGTERM handler execution flow."""
        trainer = self.get_trainer()
        manager = CheckpointManager(trainer, kill_wait=2)

        # Replace the real timer so no background thread is started.
        mock_timer_instance = Mock()
        mock_timer.return_value = mock_timer_instance

        # The first SIGTERM does not request a checkpoint immediately...
        self.assertFalse(manager.is_checkpoint_requested)
        manager._sigterm_handler(signal.SIGTERM, None)
        self.assertFalse(manager.is_checkpoint_requested)

        # ...instead it schedules _enable_checkpoint after kill_wait seconds.
        mock_timer.assert_called_once_with(2, manager._enable_checkpoint)
        mock_timer_instance.start.assert_called_once()

        # Simulate the timer firing.
        manager._enable_checkpoint()
        self.assertTrue(manager.is_checkpoint_requested)

        # Another SIGTERM while a checkpoint is already requested schedules nothing new.
        mock_timer.reset_mock()
        manager._sigterm_handler(signal.SIGTERM, None)
        mock_timer.assert_not_called()

    def test_toggle_checkpoint_flag(self):
        """Test that _enable_checkpoint sets the checkpoint request flag."""
        trainer = self.get_trainer()
        manager = CheckpointManager(trainer)

        self.assertFalse(manager.is_checkpoint_requested)
        manager._enable_checkpoint()
        self.assertTrue(manager.is_checkpoint_requested)

    def test_execute_jit_checkpoint(self):
        """Test the checkpoint execution logic with sentinel file."""
        from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR

        trainer = self.get_trainer()
        manager = CheckpointManager(trainer)

        # Mock the real save so only the manager's own logic is exercised.
        trainer._save_checkpoint = Mock()
        trainer.state.global_step = 42

        manager.is_checkpoint_requested = True
        manager.execute_jit_checkpoint()

        trainer._save_checkpoint.assert_called_once_with(trainer.model, trial=None)

        # The request flag is cleared by the execution.
        self.assertFalse(manager.is_checkpoint_requested)

        # No "incomplete" sentinel is left behind after a successful save.
        checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-42"
        sentinel_file = os.path.join(self.test_dir, checkpoint_folder, "checkpoint-is-incomplete.txt")
        self.assertFalse(os.path.exists(sentinel_file))

    def test_execute_jit_checkpoint_sentinel_file_cleanup(self):
        """Test that the sentinel file exists while saving and is cleaned up after a successful checkpoint."""
        from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR

        trainer = self.get_trainer()
        manager = CheckpointManager(trainer)
        trainer.state.global_step = 42

        checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-42"
        sentinel_file = os.path.join(self.test_dir, checkpoint_folder, "checkpoint-is-incomplete.txt")

        # Record whether the sentinel is present at the moment the save runs. Checking only
        # afterwards would also pass if the sentinel were never written at all.
        sentinel_present_during_save = []
        trainer._save_checkpoint = Mock(
            side_effect=lambda *args, **kwargs: sentinel_present_during_save.append(os.path.exists(sentinel_file))
        )

        manager.execute_jit_checkpoint()

        trainer._save_checkpoint.assert_called_once_with(trainer.model, trial=None)
        self.assertEqual(sentinel_present_during_save, [True])
        self.assertFalse(os.path.exists(sentinel_file))

    def test_execute_jit_checkpoint_with_exception(self):
        """Test that a failing save propagates its exception and still clears the request flag."""
        trainer = self.get_trainer()
        manager = CheckpointManager(trainer)

        # Make the underlying save fail.
        trainer._save_checkpoint = Mock(side_effect=Exception("Checkpoint failed"))
        trainer.state.global_step = 42

        with self.assertRaises(Exception) as context:
            manager.execute_jit_checkpoint()

        self.assertEqual(str(context.exception), "Checkpoint failed")

        # The flag is cleared even though the save failed.
        self.assertFalse(manager.is_checkpoint_requested)

    def test_jit_checkpoint_callback_initialization(self):
        """Test JITCheckpointCallback initialization."""
        callback = JITCheckpointCallback()

        self.assertIsNone(callback.trainer)
        self.assertIsNone(callback.jit_manager)

    def test_jit_checkpoint_callback_set_trainer_enabled(self):
        """Test setting trainer with JIT checkpointing enabled."""
        trainer = self.get_trainer(enable_jit=True)
        callback = JITCheckpointCallback()

        # Patch the handler installation so this test does not touch SIGTERM itself.
        with patch.object(CheckpointManager, "setup_signal_handler") as mock_setup:
            callback.set_trainer(trainer)

            self.assertEqual(callback.trainer, trainer)
            self.assertIsNotNone(callback.jit_manager)
            self.assertIsInstance(callback.jit_manager, CheckpointManager)
            mock_setup.assert_called_once()

    def test_jit_checkpoint_callback_set_trainer_disabled(self):
        """Test setting trainer with JIT checkpointing disabled."""
        trainer = self.get_trainer(enable_jit=False)
        callback = JITCheckpointCallback()

        callback.set_trainer(trainer)

        self.assertEqual(callback.trainer, trainer)
        self.assertIsNone(callback.jit_manager)

    def test_jit_checkpoint_callback_on_pre_optimizer_step(self):
        """Test callback behavior during pre-optimizer step."""
        trainer = self.get_trainer()
        callback = JITCheckpointCallback()
        callback.set_trainer(trainer)

        control = Mock()
        control.should_training_stop = False

        with patch.object(callback.jit_manager, "execute_jit_checkpoint") as mock_execute:
            # No pending request: training continues and nothing is saved.
            callback.jit_manager.is_checkpoint_requested = False
            callback.on_pre_optimizer_step(trainer.args, trainer.state, control)
            self.assertFalse(control.should_training_stop)
            mock_execute.assert_not_called()

            # Pending request: checkpoint is executed and training is stopped.
            callback.jit_manager.is_checkpoint_requested = True
            callback.on_pre_optimizer_step(trainer.args, trainer.state, control)
            self.assertTrue(control.should_training_stop)
            mock_execute.assert_called_once()

    def test_jit_checkpoint_callback_on_step_begin(self):
        """Test callback behavior at step begin."""
        trainer = self.get_trainer()
        callback = JITCheckpointCallback()
        callback.set_trainer(trainer)

        control = Mock()
        control.should_training_stop = False

        with patch.object(callback.jit_manager, "execute_jit_checkpoint") as mock_execute:
            # No pending request: training continues and nothing is saved.
            callback.jit_manager.is_checkpoint_requested = False
            callback.on_step_begin(trainer.args, trainer.state, control)
            self.assertFalse(control.should_training_stop)
            mock_execute.assert_not_called()

            # Pending request: checkpoint is executed and training is stopped.
            callback.jit_manager.is_checkpoint_requested = True
            callback.on_step_begin(trainer.args, trainer.state, control)
            self.assertTrue(control.should_training_stop)
            mock_execute.assert_called_once()

    def test_jit_checkpoint_callback_on_step_end(self):
        """Test callback behavior at step end."""
        trainer = self.get_trainer()
        callback = JITCheckpointCallback()
        callback.set_trainer(trainer)

        control = Mock()
        control.should_training_stop = False
        control.should_save = True

        with patch.object(callback.jit_manager, "execute_jit_checkpoint") as mock_execute:
            # No pending request: training continues and nothing is saved.
            callback.jit_manager.is_checkpoint_requested = False
            callback.on_step_end(trainer.args, trainer.state, control)
            self.assertFalse(control.should_training_stop)
            mock_execute.assert_not_called()

            control.should_save = True

            # Pending request: JIT checkpoint replaces the regular save and training stops.
            callback.jit_manager.is_checkpoint_requested = True
            callback.on_step_end(trainer.args, trainer.state, control)
            self.assertTrue(control.should_training_stop)
            self.assertFalse(control.should_save)
            mock_execute.assert_called_once()

    def test_jit_checkpoint_callback_on_epoch_end(self):
        """Test callback behavior at epoch end."""
        trainer = self.get_trainer()
        callback = JITCheckpointCallback()
        callback.set_trainer(trainer)

        control = Mock()
        control.should_save = True
        control.should_training_stop = False

        with patch.object(callback.jit_manager, "execute_jit_checkpoint") as mock_execute:
            # No pending request: the control flags are left untouched.
            callback.jit_manager.is_checkpoint_requested = False
            callback.on_epoch_end(trainer.args, trainer.state, control)
            self.assertTrue(control.should_save)
            self.assertFalse(control.should_training_stop)
            mock_execute.assert_not_called()

            control.should_save = True
            control.should_training_stop = False

            # Pending request: JIT checkpoint replaces the regular save and training stops.
            callback.jit_manager.is_checkpoint_requested = True
            callback.on_epoch_end(trainer.args, trainer.state, control)
            self.assertFalse(control.should_save)
            self.assertTrue(control.should_training_stop)
            mock_execute.assert_called_once()

    def test_jit_checkpoint_callback_on_train_end(self):
        """Test signal handler restoration on training end."""
        trainer = self.get_trainer()
        callback = JITCheckpointCallback()

        # Install a known handler so we can check it is the one restored.
        original_handler = signal.signal(signal.SIGTERM, signal.SIG_DFL)
        try:
            callback.set_trainer(trainer)

            # set_trainer remembers the handler that was active before it.
            previous_handler = callback.jit_manager._original_sigterm_handler
            self.assertEqual(previous_handler, signal.SIG_DFL)

            control = Mock()
            callback.on_train_end(trainer.args, trainer.state, control)

            # on_train_end puts the remembered handler back.
            self.assertEqual(signal.getsignal(signal.SIGTERM), previous_handler)
        finally:
            signal.signal(signal.SIGTERM, original_handler)

    @patch("threading.Timer")
    def test_kill_wait_period(self, mock_timer):
        """Test that the configured kill wait is passed to the timer."""
        trainer = self.get_trainer()
        manager = CheckpointManager(trainer, kill_wait=5)

        mock_timer_instance = Mock()
        mock_timer.return_value = mock_timer_instance

        manager._sigterm_handler(signal.SIGTERM, None)

        mock_timer.assert_called_once_with(5, manager._enable_checkpoint)
        mock_timer_instance.start.assert_called_once()

    def test_integration_with_trainer(self):
        """Test integration of JIT checkpointing with Trainer."""
        trainer = self.get_trainer(enable_jit=True)

        # Exactly one JIT callback is registered by the Trainer itself.
        jit_callbacks = [cb for cb in trainer.callback_handler.callbacks if isinstance(cb, JITCheckpointCallback)]
        self.assertEqual(len(jit_callbacks), 1)

        jit_callback = jit_callbacks[0]
        self.assertIsNotNone(jit_callback.jit_manager)
        self.assertEqual(jit_callback.trainer, trainer)
|
|