|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| """Helper functions for the Keras implementations of models."""
|
|
|
| import multiprocessing
|
| import os
|
| import time
|
|
|
| from absl import logging
|
| import tensorflow as tf, tf_keras
|
|
|
| from tensorflow.python.eager import monitoring
|
|
|
# Process-wide TF monitoring metrics; TimeHistory sets the batch size gauge
# and the first-batch start/end cells.
global_batch_size_gauge = monitoring.IntGauge(
    '/tensorflow/training/global_batch_size', 'TF training global batch size')

# Labelled by 'type' ('start' / 'end'); values are microseconds since epoch.
first_batch_time_gauge = monitoring.IntGauge(
    '/tensorflow/training/first_batch',
    'TF training start/end time for first batch (unix epoch time in us).',
    'type')

first_batch_start_time = first_batch_time_gauge.get_cell('start')
first_batch_end_time = first_batch_time_gauge.get_cell('end')
|
|
|
|
|
class BatchTimestamp(object):
    """Record pairing a batch index with the timestamp it was taken at."""

    def __init__(self, batch_index, timestamp):
        self.batch_index, self.timestamp = batch_index, timestamp

    def __repr__(self):
        # Note: the rendered text is intentionally wrapped in single quotes.
        return (f"'BatchTimestamp<batch_index: {self.batch_index}, "
                f"timestamp: {self.timestamp}>'")
|
|
|
|
|
class TimeHistory(tf_keras.callbacks.Callback):
    """Keras callback that records step timing and training throughput."""

    def __init__(self, batch_size, log_steps, initial_step=0, logdir=None):
        """Callback for logging performance.

        Args:
          batch_size: Total batch size.
          log_steps: Interval of steps between logging of batch level stats.
          initial_step: Optional, initial step. Steps before it are not
            counted by `average_steps_per_second`.
          logdir: Optional directory to write TensorBoard summaries.
        """
        self.batch_size = batch_size
        super(TimeHistory, self).__init__()
        self.log_steps = log_steps
        # Remembered so throughput averages only count steps run here.
        self._initial_step = initial_step
        self.last_log_step = initial_step
        self.steps_before_epoch = initial_step
        self.steps_in_epoch = 0
        # Start of the current logging window; reset to None after each log
        # so the next batch opens a new window.
        self.start_time = None
        # Filled in by on_train_end; initialized so the attribute always
        # exists.
        self.train_finish_time = None

        global_batch_size_gauge.get_cell().set(batch_size)

        if logdir:
            self.summary_writer = tf.summary.create_file_writer(logdir)
        else:
            self.summary_writer = None

        # BatchTimestamp entries: the start of the first batch, then the end
        # of every logging window (every `log_steps` steps).
        self.timestamp_log = []

        # Wall-clock duration in seconds of each completed epoch.
        self.epoch_runtime_log = []

    @property
    def global_steps(self):
        """The current 1-indexed global step."""
        return self.steps_before_epoch + self.steps_in_epoch

    @property
    def average_steps_per_second(self):
        """The average training steps per second across completed epochs.

        Only steps of completed epochs run by this callback are counted
        (`initial_step` is excluded). This keeps the numerator consistent
        with the epoch durations stored in `epoch_runtime_log`.

        Raises:
          ZeroDivisionError: If no epoch has completed yet.
        """
        completed_steps = self.steps_before_epoch - self._initial_step
        return completed_steps / sum(self.epoch_runtime_log)

    @property
    def average_examples_per_second(self):
        """The average number of training examples per second across all epochs."""
        return self.average_steps_per_second * self.batch_size

    def get_examples_per_sec(self, warmup=1):
        """Calculates examples/sec through timestamp_log and skip warmup period.

        Args:
          warmup: Index into `timestamp_log` where measurement starts; earlier
            entries are skipped.

        Returns:
          Examples per second between `timestamp_log[warmup]` and the last
          entry. Requires `len(timestamp_log) > warmup` (else IndexError) and
          a non-zero time span (else ZeroDivisionError).
        """
        time_log = self.timestamp_log
        seconds = time_log[-1].timestamp - time_log[warmup].timestamp
        steps = time_log[-1].batch_index - time_log[warmup].batch_index
        return self.batch_size * steps / seconds

    def get_startup_time(self, start_time_sec):
        """Returns seconds from `start_time_sec` until the first batch began."""
        return self.timestamp_log[0].timestamp - start_time_sec

    def on_train_end(self, logs=None):
        self.train_finish_time = time.time()

        if self.summary_writer:
            self.summary_writer.flush()

    def on_epoch_begin(self, epoch, logs=None):
        self.epoch_start = time.time()

    def on_batch_begin(self, batch, logs=None):
        if not self.start_time:
            self.start_time = time.time()
            # Record the process-wide first batch start (microseconds) once.
            if not first_batch_start_time.value():
                first_batch_start_time.set(int(self.start_time * 1000000))

        # Seed the log with the start time of the very first batch.
        if not self.timestamp_log:
            self.timestamp_log.append(
                BatchTimestamp(self.global_steps, self.start_time))

    def on_batch_end(self, batch, logs=None):
        """Records elapsed time of the batch and calculates examples per second."""
        if not first_batch_end_time.value():
            first_batch_end_time.set(int(time.time() * 1000000))
        self.steps_in_epoch = batch + 1
        steps_since_last_log = self.global_steps - self.last_log_step
        if steps_since_last_log >= self.log_steps:
            now = time.time()
            # Time since the first batch of this logging window began.
            elapsed_time = now - self.start_time
            steps_per_second = steps_since_last_log / elapsed_time
            examples_per_second = steps_per_second * self.batch_size

            self.timestamp_log.append(BatchTimestamp(self.global_steps, now))
            logging.info(
                'TimeHistory: %.2f seconds, %.2f examples/second between steps %d '
                'and %d', elapsed_time, examples_per_second, self.last_log_step,
                self.global_steps)

            if self.summary_writer:
                with self.summary_writer.as_default():
                    tf.summary.scalar('steps_per_second', steps_per_second,
                                      self.global_steps)
                    tf.summary.scalar('examples_per_second', examples_per_second,
                                      self.global_steps)

            # Close the window; the next on_batch_begin starts a new one.
            self.last_log_step = self.global_steps
            self.start_time = None

    def on_epoch_end(self, epoch, logs=None):
        epoch_run_time = time.time() - self.epoch_start
        self.epoch_runtime_log.append(epoch_run_time)

        # Fold this epoch's steps into the running total.
        self.steps_before_epoch += self.steps_in_epoch
        self.steps_in_epoch = 0
|
|
|
|
|
class SimpleCheckpoint(tf_keras.callbacks.Callback):
    """Keras callback to save tf.train.Checkpoints at the end of each epoch."""

    def __init__(self, checkpoint_manager):
        """Initializes the callback.

        Args:
          checkpoint_manager: Checkpoint manager whose `save` is called after
            every epoch (presumably a `tf.train.CheckpointManager`).
        """
        super(SimpleCheckpoint, self).__init__()
        self.checkpoint_manager = checkpoint_manager

    def on_epoch_end(self, epoch, logs=None):
        """Saves a checkpoint, numbered by the manager's step counter if set."""
        # The step counter is only reachable through a private attribute.
        # NOTE(review): it is None when the manager was built without
        # `step_counter`; calling `.numpy()` on it used to crash here.
        step_counter = getattr(self.checkpoint_manager, '_step_counter', None)
        if step_counter is None:
            # Let the manager choose the checkpoint number itself.
            self.checkpoint_manager.save()
        else:
            self.checkpoint_manager.save(checkpoint_number=step_counter.numpy())
|
|
|
|
|
def set_session_config(enable_xla=False):
    """Applies global TensorFlow configuration.

    Args:
      enable_xla: If truthy, enables XLA JIT compilation via
        `tf.config.optimizer.set_jit(True)`; otherwise nothing is changed.
    """
    if not enable_xla:
        return
    tf.config.optimizer.set_jit(True)


# Alias of `set_session_config`; both names refer to the same function.
set_config_v2 = set_session_config
|
|
|
|
|
def set_gpu_thread_mode_and_count(gpu_thread_mode, datasets_num_private_threads,
                                  num_gpus, per_gpu_thread_count):
    """Set GPU thread mode and count, and adjust dataset threads count.

    Args:
      gpu_thread_mode: Value exported as the TF_GPU_THREAD_MODE environment
        variable (must be a string).
      datasets_num_private_threads: Requested number of private threads for
        the input pipeline. If falsy, a value is derived from the CPU cores
        left after reserving GPU and runtime threads.
      num_gpus: Number of GPUs in use.
      per_gpu_thread_count: Threads dedicated to each GPU; falsy means 2.

    Returns:
      The dataset private thread count: `datasets_num_private_threads` if it
      was truthy, otherwise the derived value (never less than 1).
    """
    cpu_count = multiprocessing.cpu_count()
    logging.info('Logical CPU cores: %s', cpu_count)

    # Export the per-GPU thread settings as environment variables.
    per_gpu_thread_count = per_gpu_thread_count or 2
    os.environ['TF_GPU_THREAD_MODE'] = gpu_thread_mode
    os.environ['TF_GPU_THREAD_COUNT'] = str(per_gpu_thread_count)
    logging.info('TF_GPU_THREAD_COUNT: %s', os.environ['TF_GPU_THREAD_COUNT'])
    logging.info('TF_GPU_THREAD_MODE: %s', os.environ['TF_GPU_THREAD_MODE'])

    # Budget input-pipeline threads from the cores not taken by the GPU
    # threads plus one runtime thread per GPU, capped at 8 per GPU. Clamp to
    # at least one thread so an oversubscribed host never yields <= 0.
    total_gpu_thread_count = per_gpu_thread_count * num_gpus
    num_runtime_threads = num_gpus
    if not datasets_num_private_threads:
        datasets_num_private_threads = max(
            1,
            min(cpu_count - total_gpu_thread_count - num_runtime_threads,
                num_gpus * 8))
        logging.info('Set datasets_num_private_threads to %s',
                     datasets_num_private_threads)
    # Return the value so callers can actually apply the adjusted count.
    return datasets_num_private_threads
|
|
|