| """TensorFlow Model Garden Vision training driver."""
|
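# Example invocation (flags are defined via tfm_flags.define_flags() from
# official.common.flags; the experiment name and paths below are illustrative
# placeholders, not defaults shipped with this script):
#
#   python3 train.py \
#     --experiment=retinanet_resnetfpn_coco \
#     --mode=train_and_eval \
#     --model_dir=/tmp/model_dir \
#     --config_file=/path/to/params_override.yaml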

from absl import app
from absl import flags
from absl import logging
import gin
import tensorflow as tf, tf_keras

from official.common import distribute_utils
from official.common import flags as tfm_flags
from official.core import task_factory
from official.core import train_lib
from official.core import train_utils
from official.modeling import performance
# Imported for its side effect of registering the vision tasks, configs, and
# experiments with the task/experiment factories.
from official.vision import registry_imports
from official.vision.utils import summary_manager


| FLAGS = flags.FLAGS

flags.DEFINE_bool(
    'enable_async_checkpointing',
    default=True,
    help='A boolean indicating whether to enable async checkpoint saving')


def _run_experiment_with_preemption_recovery(params, model_dir):
  """Runs the experiment and tries to reconnect when encountering a preemption."""
  keep_training = True
  while keep_training:
    preemption_watcher = None
    try:
      distribution_strategy = distribute_utils.get_distribution_strategy(
          distribution_strategy=params.runtime.distribution_strategy,
          all_reduce_alg=params.runtime.all_reduce_alg,
          num_gpus=params.runtime.num_gpus,
          tpu_address=params.runtime.tpu)
      with distribution_strategy.scope():
        task = task_factory.get_task(params.task, logging_dir=model_dir)
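        # Hook point: an active preemption watcher (for example,
        # tf.distribute.experimental.PreemptionWatcher) could be assigned
        # here; leaving it None disables active preemption detection.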
        preemption_watcher = None

      train_lib.run_experiment(
          distribution_strategy=distribution_strategy,
          task=task,
          mode=FLAGS.mode,
          params=params,
          model_dir=model_dir,
          summary_manager=None,
          eval_summary_manager=summary_manager.maybe_build_eval_summary_manager(
              params=params, model_dir=model_dir
          ),
          enable_async_checkpointing=FLAGS.enable_async_checkpointing,
      )

      keep_training = False
    except tf.errors.OpError as e:
      # Worker preemptions surface as OpErrors; when the watcher confirms a
      # preemption, wait for the workers to exit and retry from the latest
      # checkpoint instead of failing the job.
      if preemption_watcher and preemption_watcher.preemption_message:
        preemption_watcher.block_until_worker_exit()
        logging.info(
            'Some TPU workers have been preempted (message: %s), '
            'restarting training from the last checkpoint...',
            preemption_watcher.preemption_message)
        keep_training = True
      else:
        raise e from None


| def main(_):
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)
  params = train_utils.parse_configuration(FLAGS)
  model_dir = FLAGS.model_dir
  if 'train' in FLAGS.mode:
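    # Pure eval modes skip writing the yaml config; otherwise a continuous
    # eval job could race against the train job writing the same file.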
    train_utils.serialize_config(params, model_dir)
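
  # Sets the mixed-precision policy before models are built. Using
  # 'mixed_float16' or 'mixed_bfloat16' requires matching GPU/TPU support.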
  if params.runtime.mixed_precision_dtype:
    performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype)

  _run_experiment_with_preemption_recovery(params, model_dir)
  train_utils.save_gin_config(FLAGS.mode, model_dir)


| if __name__ == '__main__':
  tfm_flags.define_flags()
  flags.mark_flags_as_required(['experiment', 'mode', 'model_dir'])
  app.run(main)