# owlv2/scenic/train_lib/tests/test_train_utils.py
# Uploaded by fcxfcx: "Upload 2446 files" (commit 1327f34, verified).
# Copyright 2025 The Scenic Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for training utility functions in train_lib.train_utils.

This file covers tests for the `Chrono.paused` context manager.
"""
from unittest import mock
from absl.testing import absltest
from scenic.train_lib import train_utils
class ChronoPausedTest(absltest.TestCase):
  """Tests the Chrono.paused context manager for correct behavior.

  `time.monotonic` is patched only around the `paused` block, after the
  `Chrono` instance has been built with the real clock. The scripted
  timestamps are therefore consumed only by the pause on entry and the resume
  on exit, independent of how many clock reads `Chrono.__init__` performs.
  """

  @mock.patch("jax.block_until_ready", autospec=True)
  def test_paused_context_manager_waits_executes_the_code_block_and_resumes(
      self, mock_block_until_ready
  ):
    """Tests the Chrono.paused context manager in a normal flow."""
    chrono = train_utils.Chrono()
    # Baseline, so the check below measures only the pause made in this test.
    initial_paused_time = chrono.paused_time
    pause_time, resume_time = 101.1, 105.5
    wait_for_ops = [mock.MagicMock()]  # Dummy operations to await.
    with mock.patch(
        "time.monotonic", side_effect=[pause_time, resume_time]
    ) as mock_monotonic:
      with chrono.paused(wait_for=wait_for_ops):
        # On entry the pending ops must have been awaited, and the pause
        # timestamp must be the first scripted clock reading.
        mock_block_until_ready.assert_called_once_with(wait_for_ops)
        self.assertEqual(chrono.pause_start, pause_time)
    self.assertIsNone(chrono.pause_start)  # Should be reset by resume.
    self.assertAlmostEqual(
        chrono.paused_time, initial_paused_time + (resume_time - pause_time)
    )
    self.assertEqual(mock_monotonic.call_count, 2)  # pause and resume.

  @mock.patch("jax.block_until_ready", autospec=True)
  def test_paused_context_manager_with_exception_calls_resume(
      self, mock_block_until_ready
  ):
    """Tests that Chrono.resume is called even if an exception occurs."""
    chrono = train_utils.Chrono()
    # Baseline, so the check below measures only the pause made in this test.
    initial_paused_time = chrono.paused_time
    pause_time, resume_time = 101.1, 105.5
    wait_for_ops = ("dummy_op",)
    custom_exception = ValueError("Test exception inside context")
    with mock.patch(
        "time.monotonic", side_effect=[pause_time, resume_time]
    ) as mock_monotonic:
      # The assertions on the paused state must run inside the context
      # manager, before the exception is raised, hence the lint suppression.
      # A failing assertion raises AssertionError (not ValueError), so it is
      # not swallowed by assertRaises.
      with self.assertRaises(ValueError) as context:  # pylint: disable=g-error-prone-assert-raises
        with chrono.paused(wait_for=wait_for_ops):
          mock_block_until_ready.assert_called_once_with(wait_for_ops)
          self.assertEqual(chrono.pause_start, pause_time)
          raise custom_exception
    # The very same exception object must propagate out of `paused`.
    self.assertIs(context.exception, custom_exception)
    self.assertIsNone(chrono.pause_start)  # Should be reset by resume.
    self.assertAlmostEqual(
        chrono.paused_time, initial_paused_time + (resume_time - pause_time)
    )
    self.assertEqual(mock_monotonic.call_count, 2)  # pause and resume.
if __name__ == "__main__":
  # Run the test cases in this module with the absltest runner.
  absltest.main()