import numpy as np
import tensorflow as tf
from baselines.common import tf_util as U
from baselines.common.tests.test_with_mpi import with_mpi
from baselines import logger

try:
    from mpi4py import MPI
except ImportError:
    MPI = None


class MpiAdamOptimizer(tf.compat.v1.train.AdamOptimizer):
    """Adam optimizer that averages gradients across mpi processes."""

    def __init__(self, comm, grad_clip=None, mpi_rank_weight=1, **kwargs):
        """Store the MPI settings and forward everything else to Adam.

        Arguments:
            comm: MPI communicator used for the gradient all-reduce
            grad_clip: when not None, the local flat gradient is rescaled
                before the all-reduce (see compute_gradients)
            mpi_rank_weight: weight applied to this worker's gradient in the
                cross-worker weighted average
            **kwargs: passed unchanged to tf.compat.v1.train.AdamOptimizer
        """
        self.comm = comm
        self.grad_clip = grad_clip
        self.mpi_rank_weight = mpi_rank_weight
        super().__init__(**kwargs)

    def compute_gradients(self, loss, var_list=None, **kwargs):
        """Return (gradient, variable) pairs with gradients averaged over MPI.

        The local gradients are flattened into one vector, scaled by
        mpi_rank_weight, summed over all workers with Allreduce inside a
        py_func, divided by the summed rank weights, and finally split and
        reshaped back to the shape of each variable.

        Note that this method itself performs an Allreduce (to sum the rank
        weights) while the graph is being built, so every worker of the
        communicator has to call it.

        Arguments:
            loss: tensor to differentiate
            var_list: variables to differentiate with respect to; None is
                forwarded to the base class
            **kwargs: passed unchanged to the base compute_gradients

        Returns:
            list of (averaged gradient tensor, variable) pairs; variables
            whose gradient is None are dropped

        Raises:
            ValueError: if no variable has a gradient
        """
        grads_and_vars = super().compute_gradients(loss, var_list, **kwargs)
        grads_and_vars = [(g, v) for g, v in grads_and_vars if g is not None]
        if not grads_and_vars:
            # Fail with a clear message instead of an IndexError / concat
            # error further down.
            raise ValueError(
                'MpiAdamOptimizer.compute_gradients: no variable has a gradient')

        flat_grad = tf.concat(
            [tf.reshape(g, (-1,)) for g, _ in grads_and_vars], axis=0
        ) * self.mpi_rank_weight
        shapes = [v.shape.as_list() for _, v in grads_and_vars]
        sizes = [int(np.prod(s)) for s in shapes]

        # Sum of mpi_rank_weight over all workers: the divisor that turns the
        # summed weighted gradients into a weighted mean.
        total_weight = np.zeros(1, np.float32)
        self.comm.Allreduce(
            np.array([self.mpi_rank_weight], dtype=np.float32),
            total_weight, op=MPI.SUM)
        total_weight = total_weight[0]

        # Receive buffer reused by every call of _collect_grads.
        buf = np.zeros(sum(sizes), np.float32)
        countholder = [0]  # Counts how many times _collect_grads has been called
        stat = tf.reduce_sum(input_tensor=grads_and_vars[0][1])  # sum of first variable

        def _collect_grads(flat_grad, np_stat):
            if self.grad_clip is not None:
                # NOTE(review): grad_clip only switches clipping on; the
                # threshold is hard-coded to 1 and the value of grad_clip is
                # never read. Confirm whether it should be the threshold.
                gradnorm = np.linalg.norm(flat_grad)
                if gradnorm > 1:
                    flat_grad /= gradnorm
                logger.logkv_mean('gradnorm', gradnorm)
                logger.logkv_mean('gradclipfrac', float(gradnorm > 1))
            self.comm.Allreduce(flat_grad, buf, op=MPI.SUM)
            np.divide(buf, float(total_weight), out=buf)
            # Every 100th call (including the first) compare the sum of the
            # first variable across workers.
            if countholder[0] % 100 == 0:
                check_synced(np_stat, self.comm)
            countholder[0] += 1
            return buf

        avg_flat_grad = tf.compat.v1.py_func(_collect_grads, [flat_grad, stat], tf.float32)
        # py_func output has no static shape; restore it so tf.split works.
        avg_flat_grad.set_shape(flat_grad.shape)
        avg_grads = tf.split(avg_flat_grad, sizes, axis=0)
        avg_grads_and_vars = [
            (tf.reshape(g, v.shape), v)
            for g, (_, v) in zip(avg_grads, grads_and_vars)
        ]
        return avg_grads_and_vars


def _values_equal(left, right):
    """Compare two gathered values; safe for scalars, arrays and lists of them."""
    if isinstance(left, (list, tuple)) and isinstance(right, (list, tuple)):
        return len(left) == len(right) and all(
            _values_equal(a, b) for a, b in zip(left, right))
    return bool(np.array_equal(left, right))


def check_synced(localval, comm=None):
    """
    It's common to forget to initialize your variables to the same values, or
    (less commonly) to get them out of sync by updating them in some other way
    than adam. This function checks that the given value is the same on all
    MPI workers and raises an AssertionError otherwise.

    The values are gathered on rank 0 and only rank 0 performs the comparison.

    Arguments:
        localval: local value to compare with the other workers (a scalar,
            an array, or a list/tuple of those)
        comm: MPI communicator; MPI.COMM_WORLD is used when omitted

    Raises:
        AssertionError: on rank 0, if any worker's value differs from rank 0's
    """
    comm = comm or MPI.COMM_WORLD
    vals = comm.gather(localval)
    if comm.rank == 0:
        # Explicit raise (not `assert`) so the check is kept under `python -O`.
        if not all(_values_equal(val, vals[0]) for val in vals[1:]):
            raise AssertionError(
                'MpiAdamOptimizer detected that different workers have different weights: {}'.format(vals))


@with_mpi(timeout=5)
def test_nonfreeze():
    """Run 100 MpiAdamOptimizer steps on a small loss to check that nothing hangs."""
    np.random.seed(0)
    tf.compat.v1.set_random_seed(0)

    a = tf.Variable(np.random.randn(3).astype('float32'))
    b = tf.Variable(np.random.randn(2, 5).astype('float32'))
    loss = tf.reduce_sum(input_tensor=tf.square(a)) + tf.reduce_sum(input_tensor=tf.sin(b))

    stepsize = 1e-2
    # for some reason the session config with inter_op_parallelism_threads was causing
    # nested sess.run calls to freeze
    config = tf.compat.v1.ConfigProto(inter_op_parallelism_threads=1)
    sess = U.get_session(config=config)
    update_op = MpiAdamOptimizer(comm=MPI.COMM_WORLD, learning_rate=stepsize).minimize(loss)
    sess.run(tf.compat.v1.global_variables_initializer())
    losslist_ref = []
    for i in range(100):
        loss_val, _ = sess.run([loss, update_op])
        print(i, loss_val)
        losslist_ref.append(loss_val)