|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
r"""Training driver for AssembleNet and AssembleNet++ experiments.

Command line:

python -m official.projects.assemblenet.train \
  --mode=train_and_eval --experiment=assemblenetplus_ucf101 \
  --model_dir='YOUR MODEL SAVE GS BUCKET' \
  --config_file=./official/projects/assemblenet/ucf101_assemblenet_plus_tpu.yaml \
  --tpu=TPU_NAME
"""
|
|
|
| from absl import app
|
| from absl import flags
|
| from absl import logging
|
| import gin
|
|
|
| from official.common import distribute_utils
|
| from official.common import flags as tfm_flags
|
| from official.core import task_factory
|
| from official.core import train_lib
|
| from official.core import train_utils
|
| from official.modeling import performance
|
|
|
| from official.projects.assemblenet.configs import assemblenet as asn_configs
|
| from official.projects.assemblenet.modeling import assemblenet as asn
|
| from official.projects.assemblenet.modeling import assemblenet_plus as asnp
|
| from official.vision import registry_imports
|
|
|
|
|
# Global absl flag registry. main() reads gin_file, gin_params, model_dir, mode
# and experiment from it, and passes it whole to train_utils.parse_configuration.
FLAGS = flags.FLAGS
|
|
|
|
|
def _check_train_eval_shapes(params):
  """Checks that train and validation data share one `feature_shape`.

  In 'train_and_eval' modes a single task (and so a single backbone
  `num_frames`) is used for both datasets, so the two shapes must agree.

  Args:
    params: Experiment config returned by `train_utils.parse_configuration`.

  Raises:
    ValueError: If the train and validation feature shapes differ. An explicit
      raise is used instead of `assert`, which is stripped under `python -O`.
  """
  train_shape = params.task.train_data.feature_shape
  validation_shape = params.task.validation_data.feature_shape
  if train_shape != validation_shape:
    raise ValueError(f'train {train_shape} != validate {validation_shape}')


def _set_backbone_num_frames(params, experiment, mode):
  """Copies the input frame count into the AssembleNet(++) backbone config.

  Does nothing unless `experiment` contains 'assemblenet'. Experiments whose
  name also contains 'plus' update `backbone.assemblenet_plus`; all others
  update `backbone.assemblenet`. Modes containing 'eval' (including
  'train_and_eval') read `validation_data.feature_shape`; other modes read
  `train_data.feature_shape`.

  Args:
    params: Experiment config; its backbone sub-config is mutated in place.
    experiment: Value of the --experiment flag.
    mode: Value of the --mode flag.
  """
  if 'assemblenet' not in experiment:
    return

  backbone = params.task.model.backbone
  if 'plus' in experiment:
    backbone_cfg = backbone.assemblenet_plus
  else:
    backbone_cfg = backbone.assemblenet

  if 'eval' in mode:
    data_cfg = params.task.validation_data
  else:
    data_cfg = params.task.train_data

  shape = data_cfg.feature_shape
  # Element 0 of feature_shape is used as the frame count (presumably the time
  # axis -- confirm against the data config).
  backbone_cfg.num_frames = shape[0]
  logging.info('mode %r num_frames %r feature shape %r', mode,
               backbone_cfg.num_frames, shape)


def main(_):
  """Builds the experiment config from flags/gin and runs it.

  Args:
    _: Unused positional argv passed in by `absl.app.run`.

  Raises:
    ValueError: In 'train_and_eval' modes, if train and validation data have
      different feature shapes.
  """
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)
  params = train_utils.parse_configuration(FLAGS)
  model_dir = FLAGS.model_dir

  # Validate before anything is written to model_dir.
  if 'train_and_eval' in FLAGS.mode:
    _check_train_eval_shapes(params)

  if 'train' in FLAGS.mode:
    # NOTE(review): only modes containing 'train' serialize the config,
    # presumably so an eval-only job sharing model_dir does not race to write
    # the same file. This runs before the num_frames override below, so the
    # serialized config does not reflect it -- confirm that is intended.
    train_utils.serialize_config(params, model_dir)

  _set_backbone_num_frames(params, FLAGS.experiment, FLAGS.mode)

  # The precision policy is set before the strategy and task are created.
  if params.runtime.mixed_precision_dtype:
    performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype)
  distribution_strategy = distribute_utils.get_distribution_strategy(
      distribution_strategy=params.runtime.distribution_strategy,
      all_reduce_alg=params.runtime.all_reduce_alg,
      num_gpus=params.runtime.num_gpus,
      tpu_address=params.runtime.tpu)
  with distribution_strategy.scope():
    task = task_factory.get_task(params.task, logging_dir=model_dir)

  train_lib.run_experiment(
      distribution_strategy=distribution_strategy,
      task=task,
      mode=FLAGS.mode,
      params=params,
      model_dir=model_dir)

  train_utils.save_gin_config(FLAGS.mode, model_dir)
|
|
|
if __name__ == '__main__':
  # Define the shared training flags from official.common.flags; this must
  # happen before they are marked required and before absl parses argv.
  tfm_flags.define_flags()
  # absl rejects the command line at parse time if any of these are missing.
  flags.mark_flags_as_required(['experiment', 'mode', 'model_dir'])
  app.run(main)
|
|
|