|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| """Define flags are common for both train.py and eval.py scripts."""
|
| import logging
|
| import sys
|
|
|
| from tensorflow.compat.v1 import flags
|
|
|
| import datasets
|
| import model
|
|
|
# Module-level alias for `flags.FLAGS`; the helper functions in this module
# (get_crop_size, create_dataset, create_mparams) read flag values through it.
FLAGS = flags.FLAGS
|
|
|
# Log line layout: "<LEVEL> <YYYY-mm-dd HH:MM:SS>.<mmm>: <file>: <line> <msg>".
# `msecs` is the millisecond part (0-999) of the record's timestamp, so it is
# zero-padded to exactly three digits; wider padding (e.g. 06d) would render
# 123 ms as ".000123", which reads like microseconds.
_LOG_FORMAT = ('%(levelname)s '
               '%(asctime)s.%(msecs)03d: '
               '%(filename)s: '
               '%(lineno)d '
               '%(message)s')
_LOG_DATEFMT = '%Y-%m-%d %H:%M:%S'

# Configure the root logger at import time: DEBUG and above go to stderr.
# NOTE: logging.basicConfig does nothing if the root logger already has
# handlers, so whichever module configures logging first wins.
logging.basicConfig(
    level=logging.DEBUG,
    stream=sys.stderr,
    format=_LOG_FORMAT,
    datefmt=_LOG_DATEFMT)
|
|
|
|
|
# Guard flag: becomes True once define() has registered the shared flags, so
# repeated calls are no-ops (presumably to avoid duplicate-flag errors when
# more than one entry point calls define()).
_common_flags_defined = False


def define():
    """Register the command-line flags shared by train.py and eval.py.

    Only the first call registers anything; every later call returns
    immediately because the module-level `_common_flags_defined` guard is
    already set.
    """
    global _common_flags_defined
    if _common_flags_defined:
        return
    _common_flags_defined = True

    # Each entry is (definer, flag name, default, help text); the flags are
    # registered in exactly the order listed here.
    flag_specs = (
        # Input pipeline and dataset selection.
        (flags.DEFINE_integer, 'batch_size', 32,
         'Batch size.'),
        (flags.DEFINE_integer, 'crop_width', None,
         'Width of the central crop for images.'),
        (flags.DEFINE_integer, 'crop_height', None,
         'Height of the central crop for images.'),
        (flags.DEFINE_string, 'train_log_dir', '/tmp/attention_ocr/train',
         'Directory where to write event logs.'),
        (flags.DEFINE_string, 'dataset_name', 'fsns',
         'Name of the dataset. Supported: fsns'),
        (flags.DEFINE_string, 'split_name', 'train',
         'Dataset split name to run evaluation for: test,train.'),
        (flags.DEFINE_string, 'dataset_dir', None,
         'Dataset root folder.'),
        (flags.DEFINE_string, 'checkpoint', '',
         'Path for checkpoint to restore weights from.'),
        (flags.DEFINE_string, 'master', '',
         'BNS name of the TensorFlow master to use.'),
        # Optimization.
        (flags.DEFINE_float, 'learning_rate', 0.004,
         'learning rate'),
        (flags.DEFINE_string, 'optimizer', 'momentum',
         'the optimizer to use'),
        (flags.DEFINE_float, 'momentum', 0.9,
         'momentum value for the momentum optimizer if used'),
        (flags.DEFINE_bool, 'use_augment_input', True,
         'If True will use image augmentation'),
        # Model architecture (read by create_mparams()).
        (flags.DEFINE_string, 'final_endpoint', 'Mixed_5d',
         'Endpoint to cut inception tower'),
        (flags.DEFINE_bool, 'use_attention', True,
         'If True will use the attention mechanism'),
        (flags.DEFINE_bool, 'use_autoregression', True,
         'If True will use autoregression (a feedback link)'),
        (flags.DEFINE_integer, 'num_lstm_units', 256,
         'number of LSTM units for sequence LSTM'),
        (flags.DEFINE_float, 'weight_decay', 0.00004,
         'weight decay for char prediction FC layers'),
        (flags.DEFINE_float, 'lstm_state_clip_value', 10.0,
         'cell state is clipped by this value prior to the cell'
         ' output activation'),
        # Sequence loss (read by create_mparams()).
        (flags.DEFINE_float, 'label_smoothing', 0.1,
         'weight for label smoothing'),
        (flags.DEFINE_bool, 'ignore_nulls', True,
         'ignore null characters for computing the loss'),
        (flags.DEFINE_bool, 'average_across_timesteps', False,
         'divide the returned cost by the total label weight'),
    )
    for define_flag, flag_name, default, help_text in flag_specs:
        define_flag(flag_name, default, help_text)
|
|
|
|
|
|
|
def get_crop_size():
    """Return the configured central-crop size, or None when cropping is off.

    Returns:
      A ``(crop_width, crop_height)`` tuple when both --crop_width and
      --crop_height hold truthy values; otherwise ``None``. A flag left at its
      ``None`` default, or set to 0, therefore disables cropping.
    """
    width = FLAGS.crop_width
    height = FLAGS.crop_height
    if not (width and height):
        return None
    return (width, height)
|
|
|
|
|
def create_dataset(split_name):
    """Load one split of the dataset selected by --dataset_name.

    The dataset module is looked up by name as an attribute of the `datasets`
    package, and its `get_split` is called with --dataset_dir as the root.

    Args:
      split_name: name of the split to load; forwarded to `get_split`.

    Returns:
      Whatever the selected dataset module's `get_split` returns.
    """
    dataset_module = getattr(datasets, FLAGS.dataset_name)
    dataset_root = FLAGS.dataset_dir
    return dataset_module.get_split(split_name, dataset_dir=dataset_root)
|
|
|
|
|
def create_mparams():
    """Build the model hyper-parameter dict from the current flag values.

    Returns:
      A dict with the keys 'conv_tower_fn', 'sequence_logit_fn' and
      'sequence_loss_fn', holding a model.ConvTowerParams,
      model.SequenceLogitsParams and model.SequenceLossParams respectively,
      each filled in from the corresponding flags.
    """
    conv_tower_params = model.ConvTowerParams(
        final_endpoint=FLAGS.final_endpoint)
    sequence_logit_params = model.SequenceLogitsParams(
        use_attention=FLAGS.use_attention,
        use_autoregression=FLAGS.use_autoregression,
        num_lstm_units=FLAGS.num_lstm_units,
        weight_decay=FLAGS.weight_decay,
        lstm_state_clip_value=FLAGS.lstm_state_clip_value)
    sequence_loss_params = model.SequenceLossParams(
        label_smoothing=FLAGS.label_smoothing,
        ignore_nulls=FLAGS.ignore_nulls,
        average_across_timesteps=FLAGS.average_across_timesteps)

    return {
        'conv_tower_fn': conv_tower_params,
        'sequence_logit_fn': sequence_logit_params,
        'sequence_loss_fn': sequence_loss_params,
    }
|
|
|
|
|
def create_model(*args, **kwargs):
    """Construct a model.Model configured from the current flag values.

    All positional and keyword arguments are forwarded unchanged to
    model.Model. The `mparams` keyword is always supplied from
    create_mparams(), so callers must not pass it themselves; doing so raises
    TypeError for a duplicate keyword argument.
    """
    return model.Model(*args, mparams=create_mparams(), **kwargs)
|
|
|