"""Training script for the DeepLab model. |
|
|
|
|
|
See model.py for more details and usage. |
|
|
""" |

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import six
import tensorflow as tf
from tensorflow.contrib import quantize as contrib_quantize
from tensorflow.contrib import tfprof as contrib_tfprof
from deeplab import common
from deeplab import model
from deeplab.datasets import data_generator
from deeplab.utils import train_utils
from deployment import model_deploy

slim = tf.contrib.slim
flags = tf.app.flags
FLAGS = flags.FLAGS
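
# Settings for multi-GPUs/multi-replicas training.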

flags.DEFINE_integer('num_clones', 1, 'Number of clones to deploy.')

flags.DEFINE_boolean('clone_on_cpu', False, 'Use CPUs to deploy clones.')

flags.DEFINE_integer('num_replicas', 1, 'Number of worker replicas.')

flags.DEFINE_integer('startup_delay_steps', 15,
                     'Number of training steps between replicas startup.')

flags.DEFINE_integer(
    'num_ps_tasks', 0,
    'The number of parameter servers. If the value is 0, then '
    'the parameters are handled locally by the worker.')

flags.DEFINE_string('master', '', 'BNS name of the tensorflow server.')

flags.DEFINE_integer('task', 0, 'The task ID.')
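
# Settings for logging.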

flags.DEFINE_string('train_logdir', None,
                    'Where the checkpoint and logs are stored.')

flags.DEFINE_integer('log_steps', 10,
                     'Display logging information at every log_steps.')

flags.DEFINE_integer('save_interval_secs', 1200,
                     'How often, in seconds, we save the model to disk.')

flags.DEFINE_integer('save_summaries_secs', 600,
                     'How often, in seconds, we compute the summaries.')

flags.DEFINE_boolean(
    'save_summaries_images', False,
    'Save sample inputs, labels, and semantic predictions as '
    'images to summary.')
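
# Settings for profiling.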

flags.DEFINE_string('profile_logdir', None,
                    'Where the profile files are stored.')
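
# Settings for training strategy.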

flags.DEFINE_enum('optimizer', 'momentum', ['momentum', 'adam'],
                  'Which optimizer to use.')
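
# Momentum optimizer flags.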

flags.DEFINE_enum('learning_policy', 'poly', ['poly', 'step'],
                  'Learning rate policy for training.')

flags.DEFINE_float('base_learning_rate', .0001,
                   'The base learning rate for model training.')

flags.DEFINE_float('decay_steps', 0.0,
                   'Decay steps for polynomial learning rate schedule.')

flags.DEFINE_float('end_learning_rate', 0.0,
                   'End learning rate for polynomial learning rate schedule.')

flags.DEFINE_float('learning_rate_decay_factor', 0.1,
                   'The rate to decay the base learning rate.')

flags.DEFINE_integer('learning_rate_decay_step', 2000,
                     'Decay the base learning rate at a fixed step.')

flags.DEFINE_float('learning_power', 0.9,
                   'The power value used in the poly learning policy.')

flags.DEFINE_integer('training_number_of_steps', 30000,
                     'The number of steps used for training.')

flags.DEFINE_float('momentum', 0.9, 'The momentum value to use.')
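
# Adam optimizer flags.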

flags.DEFINE_float('adam_learning_rate', 0.001,
                   'Learning rate for the adam optimizer.')

flags.DEFINE_float('adam_epsilon', 1e-08, 'Adam optimizer epsilon.')
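
# When fine_tune_batch_norm=True, use a batch size of at least 12 (more than
# 16 is better). Otherwise, use a smaller batch size and set
# fine_tune_batch_norm=False.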

flags.DEFINE_integer('train_batch_size', 8,
                     'The number of images in each batch during training.')

flags.DEFINE_float('weight_decay', 0.00004,
                   'The value of the weight decay for training.')

flags.DEFINE_list('train_crop_size', '513,513',
                  'Image crop size [height, width] during training.')

flags.DEFINE_float(
    'last_layer_gradient_multiplier', 1.0,
    'The gradient multiplier for last layers, which is used to '
    'boost the gradient of last layers if the value > 1.')

flags.DEFINE_boolean('upsample_logits', True,
                     'Upsample logits during training.')
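
# Hyper-parameters for NAS training strategy.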

flags.DEFINE_float(
    'drop_path_keep_prob', 1.0,
    'Probability to keep each path in the NAS cell when training.')
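
# Settings for fine-tuning the network.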

flags.DEFINE_string('tf_initial_checkpoint', None,
                    'The initial checkpoint in tensorflow format.')
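
# Set to False if one does not want to re-use the trained classifier weights.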

flags.DEFINE_boolean('initialize_last_layer', True,
                     'Initialize the last layer.')

flags.DEFINE_boolean('last_layers_contain_logits_only', False,
                     'Whether to consider only the logits as the last layers.')

flags.DEFINE_integer('slow_start_step', 0,
                     'Train the model with a small learning rate for this '
                     'many initial steps.')

flags.DEFINE_float('slow_start_learning_rate', 1e-4,
                   'Learning rate employed during slow start.')

flags.DEFINE_boolean('fine_tune_batch_norm', True,
                     'Fine tune the batch norm parameters or not.')

flags.DEFINE_float('min_scale_factor', 0.5,
                   'Minimum scale factor for data augmentation.')

flags.DEFINE_float('max_scale_factor', 2.,
                   'Maximum scale factor for data augmentation.')

flags.DEFINE_float('scale_factor_step_size', 0.25,
                   'Scale factor step size for data augmentation.')
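
# For xception_65, use atrous_rates = [12, 24, 36] if output_stride = 8, or
# atrous_rates = [6, 12, 18] if output_stride = 16. For mobilenet_v2, use
# None. Note one could use different atrous_rates/output_stride during
# training/evaluation.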

flags.DEFINE_multi_integer('atrous_rates', None,
                           'Atrous rates for atrous spatial pyramid pooling.')

flags.DEFINE_integer('output_stride', 16,
                     'The ratio of input to output spatial resolution.')
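
# Hard example mining related flags.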

flags.DEFINE_integer(
    'hard_example_mining_step', 0,
    'The training step in which exact hard example mining kicks off. Note we '
    'gradually reduce the mining percent to the specified '
    'top_k_percent_pixels. For example, if hard_example_mining_step=100K and '
    'top_k_percent_pixels=0.25, then mining percent will gradually reduce from '
    '100% to 25% until 100K steps after which we only mine top 25% pixels.')
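
# Illustrative ramp (a sketch of the behavior described above, not necessarily
# the exact formula used in train_utils): with hard_example_mining_step=100000
# and top_k_percent_pixels=0.25, the mined fraction at global step s is roughly
#   1.0 - (1.0 - 0.25) * min(s, 100000) / 100000,
# i.e. it decays linearly from 100% to 25%, then stays at 25%.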

flags.DEFINE_float(
    'top_k_percent_pixels', 1.0,
    'The top k percent pixels (in terms of the loss values) used to compute '
    'loss during training. This is useful for hard pixel mining.')
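
# Quantization setting.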

flags.DEFINE_integer(
    'quantize_delay_step', -1,
    'Steps to start quantized training. If < 0, will not quantize model.')
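
# Dataset settings.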

flags.DEFINE_string('dataset', 'pascal_voc_seg',
                    'Name of the segmentation dataset.')

flags.DEFINE_string('train_split', 'train',
                    'Which split of the dataset is used for training.')

flags.DEFINE_string('dataset_dir', None, 'Where the dataset resides.')


def _build_deeplab(iterator, outputs_to_num_classes, ignore_label):
  """Builds a clone of DeepLab.

  Args:
    iterator: An iterator of type tf.data.Iterator for images and labels.
    outputs_to_num_classes: A map from output type to the number of classes. For
      example, for the task of semantic segmentation with 21 semantic classes,
      we would have outputs_to_num_classes['semantic'] = 21.
    ignore_label: Ignore label.
  """
  samples = iterator.get_next()
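
  # Add name to input and label nodes so we can add to summary.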
  samples[common.IMAGE] = tf.identity(samples[common.IMAGE], name=common.IMAGE)
  samples[common.LABEL] = tf.identity(samples[common.LABEL], name=common.LABEL)

  model_options = common.ModelOptions(
      outputs_to_num_classes=outputs_to_num_classes,
      crop_size=[int(sz) for sz in FLAGS.train_crop_size],
      atrous_rates=FLAGS.atrous_rates,
      output_stride=FLAGS.output_stride)

  outputs_to_scales_to_logits = model.multi_scale_logits(
      samples[common.IMAGE],
      model_options=model_options,
      image_pyramid=FLAGS.image_pyramid,
      weight_decay=FLAGS.weight_decay,
      is_training=True,
      fine_tune_batch_norm=FLAGS.fine_tune_batch_norm,
      nas_training_hyper_parameters={
          'drop_path_keep_prob': FLAGS.drop_path_keep_prob,
          'total_training_steps': FLAGS.training_number_of_steps,
      })
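
  # Add name to graph node so we can add to summary.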
  output_type_dict = outputs_to_scales_to_logits[common.OUTPUT_TYPE]
  output_type_dict[model.MERGED_LOGITS_SCOPE] = tf.identity(
      output_type_dict[model.MERGED_LOGITS_SCOPE], name=common.OUTPUT_TYPE)

  for output, num_classes in six.iteritems(outputs_to_num_classes):
    train_utils.add_softmax_cross_entropy_loss_for_each_scale(
        outputs_to_scales_to_logits[output],
        samples[common.LABEL],
        num_classes,
        ignore_label,
        loss_weight=model_options.label_weights,
        upsample_logits=FLAGS.upsample_logits,
        hard_example_mining_step=FLAGS.hard_example_mining_step,
        top_k_percent_pixels=FLAGS.top_k_percent_pixels,
        scope=output)


def main(unused_argv):
  tf.logging.set_verbosity(tf.logging.INFO)

  config = model_deploy.DeploymentConfig(
      num_clones=FLAGS.num_clones,
      clone_on_cpu=FLAGS.clone_on_cpu,
      replica_id=FLAGS.task,
      num_replicas=FLAGS.num_replicas,
      num_ps_tasks=FLAGS.num_ps_tasks)
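
  # Split the batch across GPUs.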
  assert FLAGS.train_batch_size % config.num_clones == 0, (
      'Training batch size not divisible by number of clones (GPUs).')

  clone_batch_size = FLAGS.train_batch_size // config.num_clones

  tf.gfile.MakeDirs(FLAGS.train_logdir)
  tf.logging.info('Training on %s set', FLAGS.train_split)

  with tf.Graph().as_default() as graph:
    with tf.device(config.inputs_device()):
      dataset = data_generator.Dataset(
          dataset_name=FLAGS.dataset,
          split_name=FLAGS.train_split,
          dataset_dir=FLAGS.dataset_dir,
          batch_size=clone_batch_size,
          crop_size=[int(sz) for sz in FLAGS.train_crop_size],
          min_resize_value=FLAGS.min_resize_value,
          max_resize_value=FLAGS.max_resize_value,
          resize_factor=FLAGS.resize_factor,
          min_scale_factor=FLAGS.min_scale_factor,
          max_scale_factor=FLAGS.max_scale_factor,
          scale_factor_step_size=FLAGS.scale_factor_step_size,
          model_variant=FLAGS.model_variant,
          num_readers=4,
          is_training=True,
          should_shuffle=True,
          should_repeat=True)

    with tf.device(config.variables_device()):
      global_step = tf.train.get_or_create_global_step()
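
      # Define the model and create clones.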
      model_fn = _build_deeplab
      model_args = (dataset.get_one_shot_iterator(), {
          common.OUTPUT_TYPE: dataset.num_of_classes
      }, dataset.ignore_label)
      clones = model_deploy.create_clones(config, model_fn, args=model_args)
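
      # Gather update_ops from the first clone. These contain, for example,
      # the updates for the batch_norm variables created by model_fn.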
      first_clone_scope = config.clone_scope(0)
      update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)
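
    # Gather initial summaries.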
    summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
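
    # Add summaries for model variables.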
    for model_var in tf.model_variables():
      summaries.add(tf.summary.histogram(model_var.op.name, model_var))
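
    # Add summaries for images, labels, semantic predictions.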
    if FLAGS.save_summaries_images:
      summary_image = graph.get_tensor_by_name(
          ('%s/%s:0' % (first_clone_scope, common.IMAGE)).strip('/'))
      summaries.add(
          tf.summary.image('samples/%s' % common.IMAGE, summary_image))

      first_clone_label = graph.get_tensor_by_name(
          ('%s/%s:0' % (first_clone_scope, common.LABEL)).strip('/'))

      pixel_scaling = max(1, 255 // dataset.num_of_classes)
      summary_label = tf.cast(first_clone_label * pixel_scaling, tf.uint8)
      summaries.add(
          tf.summary.image('samples/%s' % common.LABEL, summary_label))

      first_clone_output = graph.get_tensor_by_name(
          ('%s/%s:0' % (first_clone_scope, common.OUTPUT_TYPE)).strip('/'))
      predictions = tf.expand_dims(tf.argmax(first_clone_output, 3), -1)

      summary_predictions = tf.cast(predictions * pixel_scaling, tf.uint8)
      summaries.add(
          tf.summary.image(
              'samples/%s' % common.OUTPUT_TYPE, summary_predictions))
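
    # Add summaries for losses.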
    for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
      summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))
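
    # Build the optimizer based on the device specification.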
    with tf.device(config.optimizer_device()):
      learning_rate = train_utils.get_model_learning_rate(
          FLAGS.learning_policy,
          FLAGS.base_learning_rate,
          FLAGS.learning_rate_decay_step,
          FLAGS.learning_rate_decay_factor,
          FLAGS.training_number_of_steps,
          FLAGS.learning_power,
          FLAGS.slow_start_step,
          FLAGS.slow_start_learning_rate,
          decay_steps=FLAGS.decay_steps,
          end_learning_rate=FLAGS.end_learning_rate)

      summaries.add(tf.summary.scalar('learning_rate', learning_rate))

      if FLAGS.optimizer == 'momentum':
        optimizer = tf.train.MomentumOptimizer(learning_rate, FLAGS.momentum)
      elif FLAGS.optimizer == 'adam':
        optimizer = tf.train.AdamOptimizer(
            learning_rate=FLAGS.adam_learning_rate, epsilon=FLAGS.adam_epsilon)
      else:
        raise ValueError('Unknown optimizer')

    if FLAGS.quantize_delay_step >= 0:
      if FLAGS.num_clones > 1:
        raise ValueError('Quantization doesn\'t support multi-clone yet.')
      contrib_quantize.create_training_graph(
          quant_delay=FLAGS.quantize_delay_step)

    startup_delay_steps = FLAGS.task * FLAGS.startup_delay_steps

    with tf.device(config.variables_device()):
      total_loss, grads_and_vars = model_deploy.optimize_clones(
          clones, optimizer)
      total_loss = tf.check_numerics(total_loss, 'Loss is inf or nan.')
      summaries.add(tf.summary.scalar('total_loss', total_loss))
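
      # Modify the gradients for biases and last layer variables.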
      last_layers = model.get_extra_layer_scopes(
          FLAGS.last_layers_contain_logits_only)
      grad_mult = train_utils.get_model_gradient_multipliers(
          last_layers, FLAGS.last_layer_gradient_multiplier)
      if grad_mult:
        grads_and_vars = slim.learning.multiply_gradients(
            grads_and_vars, grad_mult)
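
      # Create gradient update op.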
      grad_updates = optimizer.apply_gradients(
          grads_and_vars, global_step=global_step)
      update_ops.append(grad_updates)
      update_op = tf.group(*update_ops)
      with tf.control_dependencies([update_op]):
        train_tensor = tf.identity(total_loss, name='train_op')
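
    # Add the summaries from the first clone. These contain the summaries
    # created by model_fn and the losses gathered above.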
    summaries |= set(
        tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))
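
    # Merge all summaries together.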
    summary_op = tf.summary.merge(list(summaries))
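
    # Soft placement allows placing ops without a GPU implementation on CPU.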
    session_config = tf.ConfigProto(
        allow_soft_placement=True, log_device_placement=False)

    profile_dir = FLAGS.profile_logdir
    if profile_dir is not None:
      tf.gfile.MakeDirs(profile_dir)

    with contrib_tfprof.ProfileContext(
        enabled=profile_dir is not None, profile_dir=profile_dir):
      init_fn = None
      if FLAGS.tf_initial_checkpoint:
        init_fn = train_utils.get_model_init_fn(
            FLAGS.train_logdir,
            FLAGS.tf_initial_checkpoint,
            FLAGS.initialize_last_layer,
            last_layers,
            ignore_missing_vars=True)
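
      # Start the training.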
      slim.learning.train(
          train_tensor,
          logdir=FLAGS.train_logdir,
          log_every_n_steps=FLAGS.log_steps,
          master=FLAGS.master,
          number_of_steps=FLAGS.training_number_of_steps,
          is_chief=(FLAGS.task == 0),
          session_config=session_config,
          startup_delay_steps=startup_delay_steps,
          init_fn=init_fn,
          summary_op=summary_op,
          save_summaries_secs=FLAGS.save_summaries_secs,
          save_interval_secs=FLAGS.save_interval_secs)


if __name__ == '__main__':
  flags.mark_flag_as_required('train_logdir')
  flags.mark_flag_as_required('dataset_dir')
  tf.app.run()