|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| """Tests for orbit.utils.tpu_summaries."""
|
|
|
| import functools
|
| import os
|
|
|
| from orbit.utils import common
|
| from orbit.utils import tpu_summaries
|
|
|
| import tensorflow as tf, tf_keras
|
|
|
|
|
class TrainFunctionWithSummaries(tpu_summaries.OptionalSummariesFunction):
    """Implements a two-program approach for summaries on TPU.

    When summaries should be recorded, exactly one step is run through
    `self.with_summaries`; any remaining steps are run through
    `self.without_summaries`.
    """

    def __call__(self, num_steps):
        """Runs the wrapped function, splitting work across the two programs.

        Args:
          num_steps: Number of steps requested by the caller.

        Returns:
          The output of the last program that was invoked, or `None` when
          neither program ran (summaries are not being recorded and
          `num_steps` is below 1).
        """
        # Bind the result up front: if neither branch below executes, the
        # final `return` would otherwise raise UnboundLocalError.
        output = None
        if tf.summary.should_record_summaries():
            # Only a single step goes through `with_summaries`; the rest of
            # the requested steps are handed to `without_summaries` below.
            output = self.with_summaries(tf.constant(1))
            num_steps -= 1
        if num_steps >= 1:
            output = self.without_summaries(num_steps)
        return output
|
|
|
|
|
def train_function_with_summaries(function=None, **kwargs):
    """Builds a `TrainFunctionWithSummaries`, directly or as a decorator.

    Args:
      function: The callable to wrap. When omitted, a factory is returned so
        the helper can be used as `@train_function_with_summaries(**kwargs)`.
      **kwargs: Forwarded unchanged to `TrainFunctionWithSummaries`.

    Returns:
      A `TrainFunctionWithSummaries` wrapping `function` when it is given,
      otherwise a `functools.partial` that produces one once it is called
      with the function.
    """
    if function is None:
        # Decorator-with-arguments form: defer construction until the
        # decorated function is supplied.
        return functools.partial(TrainFunctionWithSummaries, **kwargs)
    return TrainFunctionWithSummaries(function, **kwargs)
|
|
|
|
|
class DummyTrainer(tf.Module):
    """Minimal trainer used to exercise `train_function_with_summaries`.

    Each training method runs a loop that writes a "step" scalar summary and
    increments `step_counter`; the three methods differ only in how (or
    whether) they are wrapped by `train_function_with_summaries`.
    """

    def __init__(self):
        # Initialise the tf.Module base class before attaching any state.
        super().__init__()
        self.step_counter = common.create_global_step()

    @train_function_with_summaries
    def train_with_tpu_summary_optimization(self, num_steps):
        """Runs `num_steps` steps; wrapped with the bare decorator form."""
        for _ in tf.range(num_steps):
            tf.summary.scalar("step", self.step_counter, step=self.step_counter)
            self.step_counter.assign_add(1)
        return self.step_counter

    @train_function_with_summaries(
        input_signature=[tf.TensorSpec((), dtype=tf.int32)])
    def train_with_tpu_summary_optimization_and_input_signature(self, num_steps):
        """Runs `num_steps` steps; wrapped with an explicit input signature."""
        for _ in tf.range(num_steps):
            tf.summary.scalar("step", self.step_counter, step=self.step_counter)
            self.step_counter.assign_add(1)
        return self.step_counter

    def train_with_tpu_summary_optimization_no_decorator(self, num_steps):
        """Runs `num_steps` steps; left undecorated so callers can wrap it."""
        for _ in tf.range(num_steps):
            tf.summary.scalar("step", self.step_counter, step=self.step_counter)
            self.step_counter.assign_add(1)
        return self.step_counter
|
|
|
|
|
class TpuSummariesTest(tf.test.TestCase):
    """Checks which steps produce summary events for the wrapped functions."""

    def setUp(self):
        super().setUp()
        self.trainer = DummyTrainer()

    def _get_events_from_logdir(self, logdir):
        """Returns the summary events stored in the only file in `logdir`."""
        filenames = tf.io.gfile.listdir(logdir)
        self.assertLen(filenames, 1)
        event_path = os.path.join(logdir, filenames[0])
        summary_events = []
        for event in list(tf.compat.v1.train.summary_iterator(event_path)):
            # Keep only records whose "what" oneof is set to a summary.
            if event.WhichOneof("what") == "summary":
                summary_events.append(event)
        return summary_events

    def _validate_tpu_summary_optimization(self, function, *args, **kwargs):
        """Calls `function` four times and checks the recorded summary steps.

        Each call is made with `tf.constant(10)` followed by `*args` and
        `**kwargs`, while recording is enabled only when the trainer's step
        counter is a multiple of 20. Exactly two summary events are expected,
        at steps 0 and 20.

        Returns:
          The value returned by the last call to `function`.
        """
        logdir = self.get_temp_dir()

        def at_recording_step():
            return self.trainer.step_counter % 20 == 0

        with tf.summary.create_file_writer(logdir).as_default(), \
                tf.summary.record_if(at_recording_step):
            for _ in range(4):
                output = function(tf.constant(10), *args, **kwargs)
        summary_events = self._get_events_from_logdir(logdir)
        self.assertLen(summary_events, 2)
        self.assertEqual(summary_events[0].step, 0)
        self.assertEqual(summary_events[1].step, 20)
        return output

    def test_train_with_tpu_summary_optimization(self):
        result = self._validate_tpu_summary_optimization(
            self.trainer.train_with_tpu_summary_optimization)
        self.assertEqual(result, self.trainer.step_counter.numpy())

    def test_train_with_tpu_summary_optimization_no_decorator(self):
        wrapped = train_function_with_summaries(
            self.trainer.train_with_tpu_summary_optimization_no_decorator)
        result = self._validate_tpu_summary_optimization(wrapped)
        self.assertEqual(result, self.trainer.step_counter.numpy())

    def test_train_with_tpu_summary_optimization_and_input_signature(self):
        result = self._validate_tpu_summary_optimization(
            self.trainer.train_with_tpu_summary_optimization_and_input_signature)
        self.assertEqual(result, self.trainer.step_counter.numpy())
        wrapped = (
            self.trainer.train_with_tpu_summary_optimization_and_input_signature)
        expected = (tf.TensorSpec((), dtype=tf.int32),)
        # Both underlying programs must carry the declared input signature.
        for program in ("with_summaries", "without_summaries"):
            self.assertEqual(getattr(wrapped, program).input_signature, expected)
|
|
|
|
|
if __name__ == "__main__":
    # Run the test cases in this module when it is executed as a script.
    tf.test.main()
|
|
|