|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| """TensorFlow Model Garden Vision training driver."""
|
|
|
| from absl import app
|
| from absl import flags
|
| import gin
|
|
|
| from official.common import distribute_utils
|
| from official.common import flags as tfm_flags
|
| from official.core import task_factory
|
| from official.core import train_lib
|
| from official.core import train_utils
|
| from official.modeling import performance
|
|
|
| from official.projects.pix2seq.configs import pix2seq
|
| from official.projects.pix2seq.tasks import pix2seq_task
|
|
|
|
|
# Handle to absl's global flag values. The flags read by main() (gin_file,
# gin_params, mode, model_dir, and whatever train_utils.parse_configuration
# consumes) are registered by tfm_flags.define_flags() in the __main__ guard.
FLAGS = flags.FLAGS
|
|
|
|
|
def main(_):
  """Entry point for `app.run`: configures and runs a single experiment.

  The steps run in a fixed order:
    1. apply gin config files and bindings;
    2. parse the experiment config from the flags;
    3. write the config to the model directory (only for modes containing
       'train');
    4. set the mixed-precision policy, if a dtype is configured;
    5. build the distribution strategy, then the task under its scope;
    6. run the experiment;
    7. save the gin config.

  Args:
    _: Positional command-line arguments forwarded by `app.run`; ignored.
  """
  # Apply gin files/bindings ahead of parsing the experiment configuration.
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)
  params = train_utils.parse_configuration(FLAGS)

  output_dir = FLAGS.model_dir
  run_mode = FLAGS.mode

  # Substring test on purpose: every mode whose name contains 'train'
  # persists the config; eval-only modes do not.
  # NOTE(review): presumably this keeps an eval-only job from racing a
  # concurrent train job that writes the same file -- confirm.
  if 'train' in run_mode:
    train_utils.serialize_config(params, output_dir)

  runtime = params.runtime
  precision_dtype = runtime.mixed_precision_dtype
  # An empty/None dtype leaves the current precision policy untouched.
  if precision_dtype:
    performance.set_mixed_precision_policy(precision_dtype)

  strategy = distribute_utils.get_distribution_strategy(
      distribution_strategy=runtime.distribution_strategy,
      all_reduce_alg=runtime.all_reduce_alg,
      num_gpus=runtime.num_gpus,
      tpu_address=runtime.tpu,
  )

  # Only task construction happens inside the strategy scope. run_experiment
  # is called after the scope exits and receives the strategy explicitly.
  with strategy.scope():
    built_task = task_factory.get_task(params.task, logging_dir=output_dir)

  train_lib.run_experiment(
      distribution_strategy=strategy,
      task=built_task,
      mode=run_mode,
      params=params,
      model_dir=output_dir,
  )

  # Persist the gin configuration; save_gin_config is given the run mode and
  # the model directory.
  train_utils.save_gin_config(run_mode, output_dir)
|
|
|
|
|
if __name__ == '__main__':
  # Define the common Model Garden flags (official.common.flags), then make
  # the 'experiment', 'mode' and 'model_dir' flags mandatory before absl
  # parses argv and calls main().
  tfm_flags.define_flags()
  required_flag_names = ['experiment', 'mode', 'model_dir']
  flags.mark_flags_as_required(required_flag_names)
  app.run(main)
|
|
|