| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import signal |
| | from unittest.mock import MagicMock, PropertyMock, patch |
| |
|
| | import pytest |
| | import torch |
| | from lightning.pytorch import Trainer |
| |
|
| | from nemo.lightning.pytorch.callbacks.preemption import PreemptionCallback |
| |
|
| |
|
| | class TestPreemptionCallback: |
| |
|
| | @pytest.fixture |
| | def callback(self): |
| | return PreemptionCallback() |
| |
|
| | @pytest.fixture |
| | def mock_trainer(self): |
| | trainer = MagicMock(spec=Trainer) |
| | trainer.should_stop = False |
| | return trainer |
| |
|
| | def test_init(self, callback): |
| | assert callback.sig == signal.SIGTERM |
| | assert not callback._interrupted |
| | assert callback._handler_context is None |
| |
|
| | def test_custom_signal(self): |
| | custom_callback = PreemptionCallback(sig=signal.SIGUSR1) |
| | assert custom_callback.sig == signal.SIGUSR1 |
| |
|
| | @pytest.mark.parametrize("initially_supported,becomes_supported", [(False, True), (False, False), (True, True)]) |
| | def test_on_train_batch_start_distributed_init( |
| | self, callback, mock_trainer, initially_supported, becomes_supported |
| | ): |
| | with ( |
| | patch.object(PreemptionCallback, '_check_preemption_support') as mock_check, |
| | patch.object(callback, '_preemption_handler') as mock_handler, |
| | ): |
| |
|
| | mock_check.side_effect = [initially_supported, becomes_supported] |
| |
|
| | callback.on_train_start(mock_trainer, None) |
| | callback.on_train_batch_start(mock_trainer, None, None, 0) |
| |
|
| | expected_call_count = 1 if initially_supported else (1 if becomes_supported else 0) |
| | assert mock_handler.call_count == expected_call_count |
| |
|
| | if initially_supported: |
| | mock_handler.assert_called_once_with() |
| | elif becomes_supported: |
| | mock_handler.assert_called_once_with() |
| | else: |
| | mock_handler.assert_not_called() |
| |
|
| | @pytest.mark.parametrize( |
| | "is_supported,interrupted,expected", |
| | [ |
| | (True, True, True), |
| | (True, False, False), |
| | (False, True, False), |
| | (False, False, False), |
| | ], |
| | ) |
| | def test_interrupted_property(self, callback, is_supported, interrupted, expected): |
| | with ( |
| | patch.object(PreemptionCallback, '_check_preemption_support', return_value=is_supported), |
| | patch('torch.distributed.broadcast'), |
| | patch('torch.tensor', return_value=torch.tensor(interrupted)), |
| | patch('torch.cuda.is_available', return_value=True), |
| | patch('torch.cuda.current_device', return_value=0), |
| | ): |
| | callback._interrupted = interrupted |
| | assert callback.interrupted == expected |
| |
|
| | def test_on_train_start(self, callback, mock_trainer): |
| | with ( |
| | patch.object(PreemptionCallback, 'preemption_supported', new_callable=PropertyMock) as mock_supported, |
| | patch.object(callback, '_preemption_handler') as mock_handler, |
| | ): |
| |
|
| | |
| | mock_supported.return_value = True |
| | callback.on_train_start(mock_trainer, None) |
| | mock_handler.assert_called_once() |
| | mock_handler.reset_mock() |
| |
|
| | |
| | mock_supported.return_value = False |
| | callback.on_train_start(mock_trainer, None) |
| | mock_handler.assert_not_called() |
| |
|
| | def test_on_train_end(self, callback, mock_trainer): |
| | mock_context = MagicMock() |
| | callback._handler_context = mock_context |
| | callback.on_train_end(mock_trainer, None) |
| | mock_context.__exit__.assert_called_once_with(None, None, None) |
| |
|
| | @pytest.mark.parametrize("interrupted", [True, False]) |
| | def test_on_train_batch_end(self, callback, mock_trainer, interrupted): |
| | with patch.object(PreemptionCallback, 'interrupted', new_callable=lambda: property(lambda self: interrupted)): |
| | if interrupted: |
| | with pytest.raises(SystemExit): |
| | callback.on_train_batch_end(mock_trainer, None, None, None, 0) |
| | else: |
| | callback.on_train_batch_end(mock_trainer, None, None, None, 0) |
| | assert mock_trainer.should_stop == interrupted |
| |
|