|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| r"""Creates and runs `Estimator` for object detection model on TPUs.
|
|
|
| This uses the TPUEstimator API to define and run a model in TRAIN/EVAL modes.
|
| """
|
|
|
|
|
| from __future__ import absolute_import
|
| from __future__ import division
|
| from __future__ import print_function
|
|
|
| from absl import flags
|
| import tensorflow.compat.v1 as tf
|
| from tensorflow.compat.v1 import estimator as tf_estimator
|
|
|
|
|
| from object_detection import model_lib
|
|
|
# Whether the estimator should target TPUs; forwarded to model_lib in main().
tf.flags.DEFINE_bool('use_tpu', True, 'Use TPUs rather than plain CPUs')

# Cloud TPU cluster resolver flags. These three values are passed to
# TPUClusterResolver in main() to locate the TPU.
flags.DEFINE_string(
    'gcp_project',
    default=None,
    help='Project name for the Cloud TPU-enabled project. If not specified, we '
    'will attempt to automatically detect the GCE project from metadata.')
flags.DEFINE_string(
    'tpu_zone',
    default=None,
    # The help text previously said "GCE project" here (copy-paste from the
    # flag above); this flag describes the zone.
    help='GCE zone where the Cloud TPU is located. If not specified, we '
    'will attempt to automatically detect the GCE zone from metadata.')
flags.DEFINE_string(
    'tpu_name',
    default=None,
    help='Name of the Cloud TPU for Cluster Resolvers.')

# TPU execution settings, used to build the TPUConfig in main().
flags.DEFINE_integer('num_shards', 8, 'Number of shards (TPU cores).')
flags.DEFINE_integer('iterations_per_loop', 100,
                     'Number of iterations per TPU training loop.')
|
|
|
|
|
|
|
# Job-level flags: which mode to run, training sizes, eval sampling, and the
# input/output paths. All of them are read in main().
flags.DEFINE_string('mode', default='train', help='Mode to run: train, eval')
flags.DEFINE_integer(
    'train_batch_size',
    default=None,
    help='Batch size for training. '
    'If this is not provided, batch size is read from training config.')
flags.DEFINE_integer(
    'num_train_steps', default=None, help='Number of train steps.')
flags.DEFINE_boolean(
    'eval_training_data',
    default=False,
    help='If training data should be evaluated for this job.')
flags.DEFINE_integer(
    'sample_1_of_n_eval_examples',
    default=1,
    help='Will sample one of every n eval input examples, '
    'where n is provided.')
flags.DEFINE_integer(
    'sample_1_of_n_eval_on_train_examples',
    default=5,
    help='Will sample one of every n train input examples for evaluation, '
    'where n is provided. '
    'This is only used if `eval_training_data` is True.')
flags.DEFINE_string(
    'model_dir',
    default=None,
    help='Path to output model directory where event and checkpoint files '
    'will be written.')
flags.DEFINE_string(
    'pipeline_config_path',
    default=None,
    help='Path to pipeline config file.')
flags.DEFINE_integer(
    'max_eval_retries',
    default=0,
    help='If running continuous eval, the maximum number of retries upon '
    'encountering tf.errors.InvalidArgumentError. '
    'If negative, will always retry the evaluation.')

# Shared handle to the parsed flag values used throughout this module.
FLAGS = tf.flags.FLAGS
|
|
|
|
|
def main(unused_argv):
  """Builds a TPU estimator from the flags, then trains or evaluates it.

  Reads the module-level `FLAGS`, resolves the TPU master address, asks
  `model_lib.create_estimator_and_inputs` for the estimator and input
  functions, and then either trains (`--mode=train`) or runs
  `model_lib.continuous_eval` (`--mode=eval`).

  Args:
    unused_argv: Positional command-line arguments; ignored.

  Raises:
    ValueError: If `--model_dir` or `--pipeline_config_path` is unset, or if
      `--mode` is neither 'train' nor 'eval'.
  """
  required_flags = ('model_dir', 'pipeline_config_path')
  for flag_name in required_flags:
    flags.mark_flag_as_required(flag_name)
  # When this function is launched through tf.app.run(), the command line has
  # already been parsed by the time we get here, so validators registered
  # above are not evaluated for this run. Check the values explicitly so that
  # a missing flag fails here with a clear message.
  missing = [name for name in required_flags if getattr(FLAGS, name) is None]
  if missing:
    raise ValueError('Missing required flag(s): ' +
                     ', '.join('--' + name for name in missing))

  # Fail fast on a mistyped mode; otherwise the estimator would be built and
  # the job would exit without training or evaluating anything.
  if FLAGS.mode not in ('train', 'eval'):
    raise ValueError(
        "Unsupported --mode '{}'; expected 'train' or 'eval'.".format(
            FLAGS.mode))

  # Locate the TPU and use its address as both training and eval master.
  tpu_cluster_resolver = (
      tf.distribute.cluster_resolver.TPUClusterResolver(
          tpu=[FLAGS.tpu_name], zone=FLAGS.tpu_zone, project=FLAGS.gcp_project))
  tpu_grpc_url = tpu_cluster_resolver.get_master()

  config = tf_estimator.tpu.RunConfig(
      master=tpu_grpc_url,
      evaluation_master=tpu_grpc_url,
      model_dir=FLAGS.model_dir,
      tpu_config=tf_estimator.tpu.TPUConfig(
          iterations_per_loop=FLAGS.iterations_per_loop,
          num_shards=FLAGS.num_shards))

  # Only forward a batch size when the flag was given; when it is omitted the
  # argument is not passed at all (the flag help says the training config
  # value is used in that case).
  kwargs = {}
  if FLAGS.train_batch_size:
    kwargs['batch_size'] = FLAGS.train_batch_size

  train_and_eval_dict = model_lib.create_estimator_and_inputs(
      run_config=config,
      pipeline_config_path=FLAGS.pipeline_config_path,
      train_steps=FLAGS.num_train_steps,
      sample_1_of_n_eval_examples=FLAGS.sample_1_of_n_eval_examples,
      sample_1_of_n_eval_on_train_examples=(
          FLAGS.sample_1_of_n_eval_on_train_examples),
      use_tpu_estimator=True,
      use_tpu=FLAGS.use_tpu,
      num_shards=FLAGS.num_shards,
      save_final_config=FLAGS.mode == 'train',
      **kwargs)
  estimator = train_and_eval_dict['estimator']
  train_input_fn = train_and_eval_dict['train_input_fn']
  eval_input_fns = train_and_eval_dict['eval_input_fns']
  eval_on_train_input_fn = train_and_eval_dict['eval_on_train_input_fn']
  train_steps = train_and_eval_dict['train_steps']

  if FLAGS.mode == 'train':
    estimator.train(input_fn=train_input_fn, max_steps=train_steps)

  # Continuous evaluation, on either the training data or the eval data.
  if FLAGS.mode == 'eval':
    if FLAGS.eval_training_data:
      name = 'training_data'
      input_fn = eval_on_train_input_fn
    else:
      name = 'validation_data'
      # Only the first eval input function is used.
      input_fn = eval_input_fns[0]
    model_lib.continuous_eval(estimator, FLAGS.model_dir, input_fn, train_steps,
                              name, FLAGS.max_eval_retries)
|
|
|
|
|
# Script entry point: hand control to TensorFlow's app runner, which calls
# main() with the command-line arguments.
if __name__ == '__main__':
  tf.app.run(main=main)
|
|
|