|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| """Contains common flags and functions."""
|
|
|
| from __future__ import absolute_import
|
| from __future__ import division
|
| from __future__ import print_function
|
|
|
| import locale
|
| import os
|
| from absl import logging
|
| import numpy as np
|
| import tensorflow as tf
|
|
|
|
|
def get_seq_middle(seq_length):
  """Returns the relative index of the middle frame of a sequence.

  For an odd length this is the exact centre; for an even length the later of
  the two central frames is returned (e.g. 4 -> 2).

  Args:
    seq_length: Number of frames in the sequence.

  Returns:
    Integer index of the middle frame.
  """
  last_index = seq_length - 1
  # int() of the true quotient truncates toward zero.
  return last_index - int(last_index / 2)
|
|
|
|
|
def info(obj):
  """Summarizes an object for debugging output.

  Lists and tuples are described by their length plus a recursive summary of
  their first element; objects accepted by is_a_numpy_array() are described by
  shape and dtype; anything else (e.g. a TensorFlow tensor) falls back to
  str(obj).

  Args:
    obj: Any object, including None.

  Returns:
    A short human-readable description string.
  """
  if obj is None:
    return 'None.'
  # Order matters: list is tested before tuple.
  sequence_formats = (
      (list, 'List of %d... %s', 'Empty list.'),
      (tuple, 'Tuple of %d... %s', 'Empty tuple.'),
  )
  for seq_type, non_empty_fmt, empty_msg in sequence_formats:
    if isinstance(obj, seq_type):
      if not obj:
        return empty_msg
      return non_empty_fmt % (len(obj), info(obj[0]))
  if is_a_numpy_array(obj):
    return 'Array with shape: %s, dtype: %s' % (obj.shape, obj.dtype)
  return str(obj)
|
|
|
|
|
def is_a_numpy_array(obj):
  """Returns True if obj is a numpy array (or numpy scalar).

  numpy scalars such as np.float32(1.0) are accepted as well, because, like
  arrays, they expose `.shape` and `.dtype`, which is what info() reads.

  Uses isinstance instead of comparing the type's module name. The module-name
  heuristic wrongly accepted any class living in the top-level `numpy`
  namespace (e.g. ufuncs like np.add, which have no `.shape`). It also rejected
  ndarray subclasses defined in submodules (e.g. np.ma.MaskedArray).

  Args:
    obj: Any object.

  Returns:
    Boolean.
  """
  return isinstance(obj, (np.ndarray, np.generic))
|
|
|
|
|
def count_parameters(also_print=True):
  """Counts the number of parameters in the model.

  Sums `num_elements()` of the shape of every variable returned by
  get_vars_to_restore().

  Args:
    also_print: Boolean. If True, log a header, then each variable's name,
      shape and element count, and finally the grand total.

  Returns:
    The total number of parameters.
  """
  if also_print:
    logging.info('Model Parameters:')
  total = 0
  for var in get_vars_to_restore():
    var_shape = var.get_shape()
    var_size = var_shape.num_elements()
    if also_print:
      logging.info('%s %s: %s', var.op.name, var_shape,
                   format_number(var_size))
    total += var_size
  if also_print:
    logging.info('Total: %s', format_number(total))
  return total
|
|
|
|
|
def get_vars_to_restore(ckpt=None):
  """Returns list of variables that should be saved/restored.

  The list holds all trainable variables plus every global variable whose op
  name contains 'moving_mean' or 'moving_variance', sorted by op name.

  Args:
    ckpt: Path to existing checkpoint. If present, returns only the subset of
      variables whose op name exists in the given checkpoint, and logs a
      warning for each variable that is missing from it.

  Returns:
    List of all variables that need to be saved/restored.
  """
  # moving_mean / moving_variance (presumably batch-norm statistics, which are
  # typically not trainable) are collected separately so they get
  # saved/restored too.
  bn_vars = [v for v in tf.global_variables()
             if 'moving_mean' in v.op.name or 'moving_variance' in v.op.name]
  model_vars = sorted(tf.trainable_variables() + bn_vars,
                      key=lambda x: x.op.name)
  if ckpt is None:
    return model_vars

  # A set gives O(1) membership tests; the previous list lookups made this
  # O(num_model_vars * num_ckpt_vars).
  ckpt_vars = tf.contrib.framework.list_variables(ckpt)
  ckpt_var_names = {name for name, _ in ckpt_vars}
  ckpt_basename = os.path.basename(ckpt)

  restorable = []
  for v in model_vars:
    if v.op.name in ckpt_var_names:
      restorable.append(v)
    else:
      logging.warning('Missing var %s in checkpoint: %s', v.op.name,
                      ckpt_basename)
  return restorable
|
|
|
|
|
def format_number(n):
  """Formats number with thousands commas (e.g. 1234567 -> '1,234,567').

  Uses str.format's ',' grouping option rather than the locale module. As a
  result it does not require the 'en_US' locale to be installed, avoiding
  locale.Error on minimal hosts. It does not mutate process-wide locale state
  via setlocale(). It also works on Python 3.12+, where locale.format() no
  longer exists.

  Args:
    n: Number to format. Non-integers are truncated toward zero, matching a
      '%d' conversion.

  Returns:
    The formatted string.
  """
  return '{:,d}'.format(int(n))
|
|
|
|
|
def read_text_lines(filepath):
  """Reads a text file into a list of lines.

  Each line has all trailing whitespace (including the newline) removed;
  leading whitespace is kept.

  Args:
    filepath: Path of the file to read.

  Returns:
    List of strings, one per line in the file.
  """
  with open(filepath, 'r') as text_file:
    return [line.rstrip() for line in text_file]
|
|
|