| """Training script for the DeepLab model.
|
|
|
| See model.py for more details and usage.
|
| """
|
|
|
| from __future__ import absolute_import
|
| from __future__ import division
|
| from __future__ import print_function

import six
import tensorflow as tf
from tensorflow.contrib import quantize as contrib_quantize
from tensorflow.contrib import tfprof as contrib_tfprof

from deeplab import common
from deeplab import model
from deeplab.datasets import data_generator
from deeplab.utils import train_utils
from deployment import model_deploy

slim = tf.contrib.slim
flags = tf.app.flags
FLAGS = flags.FLAGS
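
# Settings for multi-GPUs/multi-replicas training.
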
flags.DEFINE_integer('num_clones', 1, 'Number of clones to deploy.')

flags.DEFINE_boolean('clone_on_cpu', False, 'Use CPUs to deploy clones.')

flags.DEFINE_integer('num_replicas', 1, 'Number of worker replicas.')

flags.DEFINE_integer('startup_delay_steps', 15,
                     'Number of training steps between replicas startup.')

flags.DEFINE_integer(
    'num_ps_tasks', 0,
    'The number of parameter servers. If the value is 0, then '
    'the parameters are handled locally by the worker.')

flags.DEFINE_string('master', '', 'BNS name of the tensorflow server')

flags.DEFINE_integer('task', 0, 'The task ID.')
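
# Settings for logging.
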
flags.DEFINE_string('train_logdir', None,
                    'Where the checkpoint and logs are stored.')

flags.DEFINE_integer('log_steps', 10,
                     'Display logging information at every log_steps.')

flags.DEFINE_integer('save_interval_secs', 1200,
                     'How often, in seconds, we save the model to disk.')

flags.DEFINE_integer('save_summaries_secs', 600,
                     'How often, in seconds, we compute the summaries.')

flags.DEFINE_boolean(
    'save_summaries_images', False,
    'Save sample inputs, labels, and semantic predictions as '
    'images to summary.')
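
# Settings for profiling.
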
flags.DEFINE_string('profile_logdir', None,
                    'Where the profile files are stored.')
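
# Settings for training strategy.
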
flags.DEFINE_enum('optimizer', 'momentum', ['momentum', 'adam'],
                  'Which optimizer to use.')

flags.DEFINE_enum('learning_policy', 'poly', ['poly', 'step'],
                  'Learning rate policy for training.')

flags.DEFINE_float('base_learning_rate', .0001,
                   'The base learning rate for model training.')

flags.DEFINE_float('decay_steps', 0.0,
                   'Decay steps for polynomial learning rate schedule.')

flags.DEFINE_float('end_learning_rate', 0.0,
                   'End learning rate for polynomial learning rate schedule.')

flags.DEFINE_float('learning_rate_decay_factor', 0.1,
                   'The rate to decay the base learning rate.')

flags.DEFINE_integer('learning_rate_decay_step', 2000,
                     'Decay the base learning rate at a fixed step.')

flags.DEFINE_float('learning_power', 0.9,
                   'The power value used in the poly learning policy.')

flags.DEFINE_integer('training_number_of_steps', 30000,
                     'The number of steps used for training.')

flags.DEFINE_float('momentum', 0.9, 'The momentum value to use.')

flags.DEFINE_float('adam_learning_rate', 0.001,
                   'Learning rate for the adam optimizer.')

flags.DEFINE_float('adam_epsilon', 1e-08, 'Adam optimizer epsilon.')

flags.DEFINE_integer('train_batch_size', 8,
                     'The number of images in each batch during training.')

flags.DEFINE_float('weight_decay', 0.00004,
                   'The value of the weight decay for training.')

flags.DEFINE_list('train_crop_size', '513,513',
                  'Image crop size [height, width] during training.')

flags.DEFINE_float(
    'last_layer_gradient_multiplier', 1.0,
    'The gradient multiplier for last layers, which is used to '
    'boost the gradient of last layers if the value > 1.')

flags.DEFINE_boolean('upsample_logits', True,
                     'Upsample logits during training.')

flags.DEFINE_float(
    'drop_path_keep_prob', 1.0,
    'Probability to keep each path in the NAS cell when training.')
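
# Settings for fine-tuning the network.
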
flags.DEFINE_string('tf_initial_checkpoint', None,
                    'The initial checkpoint in tensorflow format.')

flags.DEFINE_boolean('initialize_last_layer', True,
                     'Initialize the last layer.')

flags.DEFINE_boolean('last_layers_contain_logits_only', False,
                     'Only consider logits as last layers or not.')

flags.DEFINE_integer('slow_start_step', 0,
                     'Training model with small learning rate for few steps.')

flags.DEFINE_float('slow_start_learning_rate', 1e-4,
                   'Learning rate employed during slow start.')

flags.DEFINE_boolean('fine_tune_batch_norm', True,
                     'Fine tune the batch norm parameters or not.')

flags.DEFINE_float('min_scale_factor', 0.5,
                   'Minimum scale factor for data augmentation.')

flags.DEFINE_float('max_scale_factor', 2.,
                   'Maximum scale factor for data augmentation.')

flags.DEFINE_float('scale_factor_step_size', 0.25,
                   'Scale factor step size for data augmentation.')

flags.DEFINE_multi_integer('atrous_rates', None,
                           'Atrous rates for atrous spatial pyramid pooling.')

flags.DEFINE_integer('output_stride', 16,
                     'The ratio of input to output spatial resolution.')
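
# Hard example mining related flags.
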
flags.DEFINE_integer(
    'hard_example_mining_step', 0,
    'The training step in which exact hard example mining kicks off. Note we '
    'gradually reduce the mining percent to the specified '
    'top_k_percent_pixels. For example, if hard_example_mining_step=100K and '
    'top_k_percent_pixels=0.25, then mining percent will gradually reduce from '
    '100% to 25% until 100K steps after which we only mine top 25% pixels.')

flags.DEFINE_float(
    'top_k_percent_pixels', 1.0,
    'The top k percent pixels (in terms of the loss values) used to compute '
    'loss during training. This is useful for hard pixel mining.')
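
# Quantization setting.
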
flags.DEFINE_integer(
    'quantize_delay_step', -1,
    'Steps to start quantized training. If < 0, will not quantize model.')
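
# Dataset settings.
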
flags.DEFINE_string('dataset', 'pascal_voc_seg',
                    'Name of the segmentation dataset.')

flags.DEFINE_string('train_split', 'train',
                    'Which split of the dataset to be used for training.')

flags.DEFINE_string('dataset_dir', None, 'Where the dataset resides.')


def _build_deeplab(iterator, outputs_to_num_classes, ignore_label):
  """Builds a clone of DeepLab.

  Args:
    iterator: An iterator of type tf.data.Iterator for images and labels.
    outputs_to_num_classes: A map from output type to the number of classes. For
      example, for the task of semantic segmentation with 21 semantic classes,
      we would have outputs_to_num_classes['semantic'] = 21.
    ignore_label: Ignore label.
  """
  samples = iterator.get_next()
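
  # Add name to input and label nodes so we can add them to summaries.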
  samples[common.IMAGE] = tf.identity(samples[common.IMAGE], name=common.IMAGE)
  samples[common.LABEL] = tf.identity(samples[common.LABEL], name=common.LABEL)

  model_options = common.ModelOptions(
      outputs_to_num_classes=outputs_to_num_classes,
      crop_size=[int(sz) for sz in FLAGS.train_crop_size],
      atrous_rates=FLAGS.atrous_rates,
      output_stride=FLAGS.output_stride)

  outputs_to_scales_to_logits = model.multi_scale_logits(
      samples[common.IMAGE],
      model_options=model_options,
      image_pyramid=FLAGS.image_pyramid,
      weight_decay=FLAGS.weight_decay,
      is_training=True,
      fine_tune_batch_norm=FLAGS.fine_tune_batch_norm,
      nas_training_hyper_parameters={
          'drop_path_keep_prob': FLAGS.drop_path_keep_prob,
          'total_training_steps': FLAGS.training_number_of_steps,
      })
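
  # Add a name to the merged logits node so it can be fetched for summaries.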
  output_type_dict = outputs_to_scales_to_logits[common.OUTPUT_TYPE]
  output_type_dict[model.MERGED_LOGITS_SCOPE] = tf.identity(
      output_type_dict[model.MERGED_LOGITS_SCOPE], name=common.OUTPUT_TYPE)

  for output, num_classes in six.iteritems(outputs_to_num_classes):
    train_utils.add_softmax_cross_entropy_loss_for_each_scale(
        outputs_to_scales_to_logits[output],
        samples[common.LABEL],
        num_classes,
        ignore_label,
        loss_weight=model_options.label_weights,
        upsample_logits=FLAGS.upsample_logits,
        hard_example_mining_step=FLAGS.hard_example_mining_step,
        top_k_percent_pixels=FLAGS.top_k_percent_pixels,
        scope=output)


def main(unused_argv):
  tf.logging.set_verbosity(tf.logging.INFO)
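
  # Set up deployment (i.e., multi-GPUs and/or multi-replicas).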
  config = model_deploy.DeploymentConfig(
      num_clones=FLAGS.num_clones,
      clone_on_cpu=FLAGS.clone_on_cpu,
      replica_id=FLAGS.task,
      num_replicas=FLAGS.num_replicas,
      num_ps_tasks=FLAGS.num_ps_tasks)
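
  # Split the batch across GPUs (clones).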
  assert FLAGS.train_batch_size % config.num_clones == 0, (
      'Training batch size not divisible by number of clones (GPUs).')

  clone_batch_size = FLAGS.train_batch_size // config.num_clones

  tf.gfile.MakeDirs(FLAGS.train_logdir)
  tf.logging.info('Training on %s set', FLAGS.train_split)

  with tf.Graph().as_default() as graph:
    with tf.device(config.inputs_device()):
      dataset = data_generator.Dataset(
          dataset_name=FLAGS.dataset,
          split_name=FLAGS.train_split,
          dataset_dir=FLAGS.dataset_dir,
          batch_size=clone_batch_size,
          crop_size=[int(sz) for sz in FLAGS.train_crop_size],
          min_resize_value=FLAGS.min_resize_value,
          max_resize_value=FLAGS.max_resize_value,
          resize_factor=FLAGS.resize_factor,
          min_scale_factor=FLAGS.min_scale_factor,
          max_scale_factor=FLAGS.max_scale_factor,
          scale_factor_step_size=FLAGS.scale_factor_step_size,
          model_variant=FLAGS.model_variant,
          num_readers=4,
          is_training=True,
          should_shuffle=True,
          should_repeat=True)
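
    # Create the global step on the device storing the variables.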
    with tf.device(config.variables_device()):
      global_step = tf.train.get_or_create_global_step()
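
      # Define the model and create clones.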
      model_fn = _build_deeplab
      model_args = (dataset.get_one_shot_iterator(), {
          common.OUTPUT_TYPE: dataset.num_of_classes
      }, dataset.ignore_label)
      clones = model_deploy.create_clones(config, model_fn, args=model_args)
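
      # Gather update_ops from the first clone. These contain, for example,
      # the updates for the batch norm variables created by model_fn.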
      first_clone_scope = config.clone_scope(0)
      update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)
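
    # Gather initial summaries.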
    summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
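
    # Add histogram summaries for the model variables.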
    for model_var in tf.model_variables():
      summaries.add(tf.summary.histogram(model_var.op.name, model_var))
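
    # Add summaries for sample images, labels, and semantic predictions.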
    if FLAGS.save_summaries_images:
      summary_image = graph.get_tensor_by_name(
          ('%s/%s:0' % (first_clone_scope, common.IMAGE)).strip('/'))
      summaries.add(
          tf.summary.image('samples/%s' % common.IMAGE, summary_image))

      first_clone_label = graph.get_tensor_by_name(
          ('%s/%s:0' % (first_clone_scope, common.LABEL)).strip('/'))
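
      # Scale up summary image pixel values for better visualization.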
      pixel_scaling = max(1, 255 // dataset.num_of_classes)
      summary_label = tf.cast(first_clone_label * pixel_scaling, tf.uint8)
      summaries.add(
          tf.summary.image('samples/%s' % common.LABEL, summary_label))

      first_clone_output = graph.get_tensor_by_name(
          ('%s/%s:0' % (first_clone_scope, common.OUTPUT_TYPE)).strip('/'))
      predictions = tf.expand_dims(tf.argmax(first_clone_output, 3), -1)

      summary_predictions = tf.cast(predictions * pixel_scaling, tf.uint8)
      summaries.add(
          tf.summary.image(
              'samples/%s' % common.OUTPUT_TYPE, summary_predictions))
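
    # Add summaries for losses.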
    for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
      summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))
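
    # Build the optimizer on the device chosen by the deployment config.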
    with tf.device(config.optimizer_device()):
      learning_rate = train_utils.get_model_learning_rate(
          FLAGS.learning_policy,
          FLAGS.base_learning_rate,
          FLAGS.learning_rate_decay_step,
          FLAGS.learning_rate_decay_factor,
          FLAGS.training_number_of_steps,
          FLAGS.learning_power,
          FLAGS.slow_start_step,
          FLAGS.slow_start_learning_rate,
          decay_steps=FLAGS.decay_steps,
          end_learning_rate=FLAGS.end_learning_rate)

      summaries.add(tf.summary.scalar('learning_rate', learning_rate))

      if FLAGS.optimizer == 'momentum':
        optimizer = tf.train.MomentumOptimizer(learning_rate, FLAGS.momentum)
      elif FLAGS.optimizer == 'adam':
        optimizer = tf.train.AdamOptimizer(
            learning_rate=FLAGS.adam_learning_rate, epsilon=FLAGS.adam_epsilon)
      else:
        raise ValueError('Unknown optimizer')

    if FLAGS.quantize_delay_step >= 0:
      if FLAGS.num_clones > 1:
        raise ValueError('Quantization doesn\'t support multi-clone yet.')
      contrib_quantize.create_training_graph(
          quant_delay=FLAGS.quantize_delay_step)

    startup_delay_steps = FLAGS.task * FLAGS.startup_delay_steps

    with tf.device(config.variables_device()):
      total_loss, grads_and_vars = model_deploy.optimize_clones(
          clones, optimizer)
      total_loss = tf.check_numerics(total_loss, 'Loss is inf or nan.')
      summaries.add(tf.summary.scalar('total_loss', total_loss))
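
      # Boost the gradients of the last layers if requested via
      # last_layer_gradient_multiplier.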
      last_layers = model.get_extra_layer_scopes(
          FLAGS.last_layers_contain_logits_only)
      grad_mult = train_utils.get_model_gradient_multipliers(
          last_layers, FLAGS.last_layer_gradient_multiplier)
      if grad_mult:
        grads_and_vars = slim.learning.multiply_gradients(
            grads_and_vars, grad_mult)
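
      # Create the gradient update op.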
      grad_updates = optimizer.apply_gradients(
          grads_and_vars, global_step=global_step)
      update_ops.append(grad_updates)
      update_op = tf.group(*update_ops)
      with tf.control_dependencies([update_op]):
        train_tensor = tf.identity(total_loss, name='train_op')
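
    # Also include the summaries created within the first clone's scope
    # (e.g. by model_fn).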
    summaries |= set(
        tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))
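
    # Merge all summaries together.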
    summary_op = tf.summary.merge(list(summaries))
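
    # Soft placement allows ops without a GPU implementation to run on CPU.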
    session_config = tf.ConfigProto(
        allow_soft_placement=True, log_device_placement=False)
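
    # Start the training, profiling to profile_logdir when it is set.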
    profile_dir = FLAGS.profile_logdir
    if profile_dir is not None:
      tf.gfile.MakeDirs(profile_dir)

    with contrib_tfprof.ProfileContext(
        enabled=profile_dir is not None, profile_dir=profile_dir):
      init_fn = None
      if FLAGS.tf_initial_checkpoint:
        init_fn = train_utils.get_model_init_fn(
            FLAGS.train_logdir,
            FLAGS.tf_initial_checkpoint,
            FLAGS.initialize_last_layer,
            last_layers,
            ignore_missing_vars=True)

      slim.learning.train(
          train_tensor,
          logdir=FLAGS.train_logdir,
          log_every_n_steps=FLAGS.log_steps,
          master=FLAGS.master,
          number_of_steps=FLAGS.training_number_of_steps,
          is_chief=(FLAGS.task == 0),
          session_config=session_config,
          startup_delay_steps=startup_delay_steps,
          init_fn=init_fn,
          summary_op=summary_op,
          save_summaries_secs=FLAGS.save_summaries_secs,
          save_interval_secs=FLAGS.save_interval_secs)


if __name__ == '__main__':
  flags.mark_flag_as_required('train_logdir')
  flags.mark_flag_as_required('dataset_dir')
  tf.app.run()