| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
"""Unit tests for training utility functions in train_lib.train_utils.

This file covers the Chrono.paused context manager.
"""
|
|
| from unittest import mock |
|
|
| from absl.testing import absltest |
| from scenic.train_lib import train_utils |
|
|
|
|
class ChronoPausedTest(absltest.TestCase):
  """Tests the Chrono.paused context manager for correct behavior."""

  # Fake readings returned, in order, by the patched `time.monotonic`.
  _BEFORE_PAUSE = 100.0
  _AFTER_PAUSE = 101.1
  _AFTER_RESUME = 105.5

  def _make_chrono(self, mock_monotonic):
    """Builds a Chrono and scripts the fake clock for the test body.

    Args:
      mock_monotonic: The mock patched in place of `time.monotonic`.

    Returns:
      A fresh `train_utils.Chrono` instance.
    """
    chrono = train_utils.Chrono()
    # The script is installed only after construction. The first scripted
    # value therefore goes to the first clock read made once the test starts
    # using `chrono`.
    mock_monotonic.side_effect = [
        self._BEFORE_PAUSE,
        self._AFTER_PAUSE,
        self._AFTER_RESUME,
    ]
    return chrono

  def _assert_resumed(self, chrono, mock_monotonic):
    """Checks that `chrono` left the paused state and recorded the pause.

    Args:
      chrono: The Chrono that was used with `paused()`.
      mock_monotonic: The mock patched in place of `time.monotonic`.
    """
    self.assertIsNone(chrono.pause_start)
    # The paused interval is expected to span the first two scripted
    # readings. Use an approximate compare because the expected value comes
    # from float subtraction.
    self.assertAlmostEqual(
        chrono.paused_time, self._AFTER_PAUSE - self._BEFORE_PAUSE
    )
    # `call_count` counts every clock read since the patch was applied,
    # including any made while constructing the Chrono.
    self.assertEqual(mock_monotonic.call_count, 3)

  @mock.patch("jax.block_until_ready", autospec=True)
  @mock.patch("time.monotonic")
  def test_paused_context_manager_waits_executes_the_code_block_and_resumes(
      self, mock_monotonic, mock_block_until_ready
  ):
    """Tests the Chrono.paused context manager in a normal flow."""
    chrono = self._make_chrono(mock_monotonic)
    wait_for_ops = [mock.MagicMock()]

    with chrono.paused(wait_for=wait_for_ops):
      # By the time the body runs, paused() must have blocked on the given ops
      # and recorded when the pause started.
      mock_block_until_ready.assert_called_once_with(wait_for_ops)
      self.assertEqual(chrono.pause_start, self._BEFORE_PAUSE)

    self._assert_resumed(chrono, mock_monotonic)

  @mock.patch("jax.block_until_ready", autospec=True)
  @mock.patch("time.monotonic")
  def test_paused_context_manager_with_exception_calls_resume(
      self, mock_monotonic, mock_block_until_ready
  ):
    """Tests that Chrono.resume is called even if an exception occurs."""
    chrono = self._make_chrono(mock_monotonic)
    wait_for_ops = ("dummy_op",)
    custom_exception = ValueError("Test exception inside context")

    with self.assertRaises(ValueError) as context:
      with chrono.paused(wait_for=wait_for_ops):
        mock_block_until_ready.assert_called_once_with(wait_for_ops)
        self.assertEqual(chrono.pause_start, self._BEFORE_PAUSE)
        raise custom_exception
    # This runs only after assertRaises has captured the exception. The very
    # same object must propagate out of paused(), neither swallowed nor
    # replaced.
    self.assertIs(context.exception, custom_exception)

    self._assert_resumed(chrono, mock_monotonic)
|
|
|
|
if __name__ == "__main__":
  # Entry point when the module is executed directly: delegate to absltest.
  absltest.main()
|
|