|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| """A customized training library for the specific task."""
|
|
|
| from absl import app
|
| from absl import flags
|
| import gin
|
|
|
| from official.common import distribute_utils
|
| from official.common import flags as tfm_flags
|
| from official.core import train_lib
|
| from official.core import train_utils
|
| from official.modeling import performance
|
| from official.projects.text_classification_example import classification_example
|
|
|
# Module-level handle to absl's global flag values. The actual flags are
# registered by tfm_flags.define_flags() in the __main__ guard below and
# are only read inside main() after app.run() has parsed argv.
FLAGS = flags.FLAGS
|
|
|
|
|
def main(_):
  """Configures and runs the text-classification training experiment.

  Reads all settings from absl flags: applies gin overrides, builds the
  experiment config, sets up mixed precision and the distribution
  strategy, then hands off to the generic experiment runner.
  """
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)
  experiment_params = train_utils.parse_configuration(FLAGS)
  output_dir = FLAGS.model_dir
  if 'train' in FLAGS.mode:
    # Persist the fully-resolved config so the run can be reproduced
    # (and inspected) from the model directory alone.
    train_utils.serialize_config(experiment_params, output_dir)

  runtime = experiment_params.runtime
  if runtime.mixed_precision_dtype:
    # Set the Keras mixed-precision policy before any model variables
    # are created under the strategy scope.
    performance.set_mixed_precision_policy(runtime.mixed_precision_dtype)

  strategy = distribute_utils.get_distribution_strategy(
      distribution_strategy=runtime.distribution_strategy,
      all_reduce_alg=runtime.all_reduce_alg,
      num_gpus=runtime.num_gpus,
      tpu_address=runtime.tpu,
      **runtime.model_parallelism())

  # The task (and therefore the model it builds) must be instantiated
  # inside the strategy scope so variables are placed correctly.
  with strategy.scope():
    example_task = classification_example.ClassificationExampleTask(
        experiment_params.task)

  train_lib.run_experiment(
      distribution_strategy=strategy,
      task=example_task,
      mode=FLAGS.mode,
      params=experiment_params,
      model_dir=output_dir)

  # Record the effective gin config alongside the other run artifacts.
  train_utils.save_gin_config(FLAGS.mode, output_dir)
|
|
|
|
|
if __name__ == '__main__':
  # Register the standard Model Garden training flags (mode, model_dir,
  # gin_file, etc.) before app.run() parses the command line.
  tfm_flags.define_flags()
  app.run(main)
|
|
|