| | import tensorflow as tf |
| | import numpy as np |
| | from baselines.common.tf_util import get_session |
| |
|
class RunningMeanStd(object):
    '''
    Running estimate of the per-element mean and variance of a stream of
    sample batches.

    State is kept in three attributes that are replaced on every update:

    * ``mean``  -- float64 array of shape ``shape``, initialised to zeros.
    * ``var``   -- float64 array of shape ``shape``, initialised to ones.
    * ``count`` -- weight of the data absorbed so far; starts at ``epsilon``.
    '''

    def __init__(self, epsilon=1e-4, shape=()):
        '''
        Args:
            epsilon: Initial value of ``count``. A positive value keeps the
                total count (used as a divisor when moments are merged)
                non-zero.
            shape: Shape of the ``mean`` and ``var`` arrays.
        '''
        self.mean = np.zeros(shape, 'float64')
        self.var = np.ones(shape, 'float64')
        self.count = epsilon

    def update(self, x):
        '''
        Fold a batch of samples into the running statistics.

        Args:
            x: Array-like whose first axis indexes the samples of the batch.
                Plain (nested) sequences are accepted as well as ndarrays.
        '''
        # np.mean / np.var already accept plain sequences, but `.shape` does
        # not exist on them. Converting once up front makes the whole method
        # work for array-likes; ndarrays (and subclasses) pass through as-is.
        x = np.asanyarray(x)
        batch_mean = np.mean(x, axis=0)
        batch_var = np.var(x, axis=0)
        batch_count = x.shape[0]
        self.update_from_moments(batch_mean, batch_var, batch_count)

    def update_from_moments(self, batch_mean, batch_var, batch_count):
        '''
        Merge pre-computed batch moments into ``mean``, ``var`` and ``count``.

        Args:
            batch_mean: Mean of the batch along its first axis.
            batch_var: Variance of the batch along its first axis.
            batch_count: Number of samples the batch moments were computed from.
        '''
        self.mean, self.var, self.count = update_mean_var_count_from_moments(
            self.mean, self.var, self.count, batch_mean, batch_var, batch_count)
| |
|
def update_mean_var_count_from_moments(mean, var, count, batch_mean, batch_var, batch_count):
    '''
    Combine the moments of an existing sample set with those of a new batch.

    Args:
        mean, var, count: Mean, variance and sample count accumulated so far.
        batch_mean, batch_var, batch_count: The same three quantities for the
            batch being merged in.

    Returns:
        Tuple ``(new_mean, new_var, new_count)`` describing the union of both
        sample sets. The variance is the sum of squared deviations divided by
        the total count (no Bessel correction is applied here).
    '''
    total = count + batch_count
    shift = batch_mean - mean

    # Sums of squared deviations of each part around its own mean.
    running_ss = var * count
    batch_ss = batch_var * batch_count
    # Extra spread that comes from the two parts having different means.
    between_ss = np.square(shift) * count * batch_count / total
    combined_ss = running_ss + batch_ss + between_ss

    merged_mean = mean + shift * batch_count / total
    merged_var = combined_ss / total
    return merged_mean, merged_var, total
| |
|
| |
|
class TfRunningMeanStd(object):
    '''
    Running mean / variance tracker whose state is held in TensorFlow variables.

    The benefit of keeping the statistics in variables is that they can be
    saved / loaded together with the TensorFlow model. The current values are
    also mirrored on the ``mean``, ``var`` and ``count`` attributes, which are
    refreshed from the session after construction and after every update.
    '''

    def __init__(self, epsilon=1e-4, shape=(), scope=''):
        '''
        Args:
            epsilon: Initial value of the ``count`` variable.
            shape: Shape of the ``mean`` and variance variables.
            scope: Variable scope under which the three variables are created
                (or re-used, since the scope is opened with AUTO_REUSE).
        '''
        session = get_session()
        f64 = tf.float64

        # Feed points through which freshly computed statistics are pushed
        # into the variables.
        self._new_mean = tf.compat.v1.placeholder(dtype=f64, shape=shape)
        self._new_var = tf.compat.v1.placeholder(dtype=f64, shape=shape)
        self._new_count = tf.compat.v1.placeholder(dtype=f64, shape=())

        with tf.compat.v1.variable_scope(scope, reuse=tf.compat.v1.AUTO_REUSE):
            self._mean = tf.compat.v1.get_variable(
                'mean', dtype=f64, initializer=np.zeros(shape, 'float64'))
            # NOTE: the variable is named 'std' but is initialised to ones and
            # assigned the variance; the name is left alone here.
            self._var = tf.compat.v1.get_variable(
                'std', dtype=f64, initializer=np.ones(shape, 'float64'))
            self._count = tf.compat.v1.get_variable(
                'count', dtype=f64, initializer=np.full((), epsilon, 'float64'))

        assignments = [
            self._var.assign(self._new_var),
            self._mean.assign(self._new_mean),
            self._count.assign(self._new_count),
        ]
        self.update_ops = tf.group(assignments)

        # Only this object's variables are (re-)initialised.
        tracked = [self._mean, self._var, self._count]
        session.run(tf.compat.v1.variables_initializer(tracked))
        self.sess = session
        self._set_mean_var_count()

    def _set_mean_var_count(self):
        '''Refresh ``mean``, ``var`` and ``count`` from the TensorFlow variables.'''
        fetched = self.sess.run([self._mean, self._var, self._count])
        self.mean, self.var, self.count = fetched

    def update(self, x):
        '''
        Fold a batch into the statistics and write the result to the variables.

        Args:
            x: Array whose first axis indexes the samples of the batch.
        '''
        batch_moments = (np.mean(x, axis=0), np.var(x, axis=0), x.shape[0])

        # The merge itself runs in NumPy on the mirrored values; TensorFlow is
        # only used to store the result.
        new_mean, new_var, new_count = update_mean_var_count_from_moments(
            self.mean, self.var, self.count, *batch_moments)

        feeds = {
            self._new_mean: new_mean,
            self._new_var: new_var,
            self._new_count: new_count,
        }
        self.sess.run(self.update_ops, feed_dict=feeds)

        self._set_mean_var_count()
| |
|
| |
|
| |
|
def test_runningmeanstd():
    '''
    Check that feeding RunningMeanStd three batches one at a time gives the
    same mean and variance as computing them over the concatenated data, for
    both scalar samples and 2-element samples.
    '''
    batch_sizes = (3, 4, 5)
    cases = [
        tuple(np.random.randn(n) for n in batch_sizes),
        tuple(np.random.randn(n, 2) for n in batch_sizes),
    ]

    for chunks in cases:
        # epsilon=0.0 so the initial state carries no weight and the running
        # result can match the directly computed statistics.
        tracker = RunningMeanStd(epsilon=0.0, shape=chunks[0].shape[1:])

        stacked = np.concatenate(chunks, axis=0)
        expected = [stacked.mean(axis=0), stacked.var(axis=0)]

        for chunk in chunks:
            tracker.update(chunk)
        observed = [tracker.mean, tracker.var]

        np.testing.assert_allclose(expected, observed)
| |
|
def test_tf_runningmeanstd():
    '''
    Check that feeding TfRunningMeanStd three batches one at a time gives the
    same mean and variance as computing them over the concatenated data, for
    both scalar samples and 2-element samples.
    '''
    cases = [
        (np.random.randn(3), np.random.randn(4), np.random.randn(5)),
        (np.random.randn(3, 2), np.random.randn(4, 2), np.random.randn(5, 2)),
    ]

    for case_idx, (x1, x2, x3) in enumerate(cases):
        # Each case gets its own scope, derived from the case index. A randomly
        # drawn suffix could occasionally repeat between the two cases; the
        # scope is opened with AUTO_REUSE, so the second case would then ask
        # for the existing variables with a different shape and fail.
        scope = 'running_mean_std' + str(case_idx)
        rms = TfRunningMeanStd(epsilon=0.0, shape=x1.shape[1:], scope=scope)

        x = np.concatenate([x1, x2, x3], axis=0)
        ms1 = [x.mean(axis=0), x.var(axis=0)]
        rms.update(x1)
        rms.update(x2)
        rms.update(x3)
        ms2 = [rms.mean, rms.var]

        np.testing.assert_allclose(ms1, ms2)
| |
|
| |
|
def profile_tf_runningmeanstd():
    '''
    Print a rough timing comparison between RunningMeanStd and
    TfRunningMeanStd: first for ``n_trials`` calls to ``update`` on the same
    376-element batch, then for ``n_trials`` reads of the ``mean`` attribute.
    '''
    import time
    from baselines.common import tf_util

    # Single-threaded session configuration for the measurement.
    tf_util.get_session(config=tf.compat.v1.ConfigProto(
        inter_op_parallelism_threads=1,
        intra_op_parallelism_threads=1,
        allow_soft_placement=True
    ))

    x = np.random.random((376,))

    n_trials = 10000
    rms = RunningMeanStd()
    tfrms = TfRunningMeanStd()

    # perf_counter is a monotonic, high-resolution clock meant for measuring
    # intervals; wall-clock time can jump while the loops are running.
    tic1 = time.perf_counter()
    for _ in range(n_trials):
        rms.update(x)

    tic2 = time.perf_counter()
    for _ in range(n_trials):
        tfrms.update(x)

    tic3 = time.perf_counter()

    print('rms update time ({} trials): {} s'.format(n_trials, tic2 - tic1))
    print('tfrms update time ({} trials): {} s'.format(n_trials, tic3 - tic2))

    tic1 = time.perf_counter()
    for _ in range(n_trials):
        z1 = rms.mean

    tic2 = time.perf_counter()
    for _ in range(n_trials):
        z2 = tfrms.mean

    tic3 = time.perf_counter()

    # Sanity check that both trackers ended up with the same mean. It is done
    # after the last timestamp so it is not counted in the measurement.
    assert z1 == z2

    print('rms get mean time ({} trials): {} s'.format(n_trials, tic2 - tic1))
    print('tfrms get mean time ({} trials): {} s'.format(n_trials, tic3 - tic2))

    # The string literal below is never executed; it holds a disabled snippet
    # for dumping a Chrome-trace timeline of a profiled session run.
    '''
    options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE) #pylint: disable=E1101
    run_metadata = tf.RunMetadata()
    profile_opts = dict(options=options, run_metadata=run_metadata)



    from tensorflow.python.client import timeline
    fetched_timeline = timeline.Timeline(run_metadata.step_stats) #pylint: disable=E1101
    chrome_trace = fetched_timeline.generate_chrome_trace_format()
    outfile = '/tmp/timeline.json'
    with open(outfile, 'wt') as f:
        f.write(chrome_trace)
    print('Successfully saved profile to {}. Exiting.'.format(outfile))
    exit(0)
    '''
| |
|
| |
|
| |
|
# Script entry point: running this file directly executes the timing
# comparison between the NumPy and TensorFlow trackers.
if __name__ == '__main__':
    profile_tf_runningmeanstd()
| |
|