| """Trust region optimization.
|
|
|
| A lot of this is adapted from other's code.
|
| See Schulman's Modular RL, wojzaremba's TRPO, etc.
|
|
|
| """

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from six.moves import xrange
import tensorflow as tf
import numpy as np


def var_size(v):
  """Returns the total number of elements in variable `v`."""
  return int(np.prod([int(d) for d in v.shape]))


def gradients(loss, var_list):
  """Like tf.gradients, but replaces None gradients with zeros."""
  grads = tf.gradients(loss, var_list)
  return [g if g is not None else tf.zeros(v.shape)
          for g, v in zip(grads, var_list)]


def flatgrad(loss, var_list):
  """Returns the gradient of `loss` w.r.t. var_list as a single flat vector."""
  grads = gradients(loss, var_list)
  return tf.concat([tf.reshape(grad, [-1])
                    for (v, grad) in zip(var_list, grads)
                    if grad is not None], 0)


def get_flat(var_list):
  """Flattens and concatenates all variables in var_list into one vector."""
  return tf.concat([tf.reshape(v, [-1]) for v in var_list], 0)


def set_from_flat(var_list, flat_theta):
  """Returns an op assigning the flat vector flat_theta back into var_list."""
  shapes = [v.shape for v in var_list]
  sizes = [var_size(v) for v in var_list]

  start = 0
  assigns = []
  for (shape, size, v) in zip(shapes, sizes, var_list):
    assigns.append(v.assign(
        tf.reshape(flat_theta[start:start + size], shape)))
    start += size
  assert start == sum(sizes)

  return tf.group(*assigns)
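

# A minimal round-trip sketch (assumes `sess` is an active tf.Session and
# `var_list` holds initialized variables; the names are illustrative only):
#
#   flat_theta_ph = tf.placeholder(tf.float32, [None])
#   set_op = set_from_flat(var_list, flat_theta_ph)
#   theta = sess.run(get_flat(var_list))
#   sess.run(set_op, feed_dict={flat_theta_ph: theta})  # leaves vars unchanged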


class TrustRegionOptimization(object):
  """TRPO-style optimizer: a natural gradient step found by conjugate
  gradient, followed by a divergence-constrained backtracking line search.
  """

  def __init__(self, max_divergence=0.1, cg_damping=0.1):
    self.max_divergence = max_divergence
    self.cg_damping = cg_damping

  def setup_placeholders(self):
    # flat_tangent feeds candidate vectors for Fisher-vector products;
    # flat_theta feeds flattened parameters for assignment.
    self.flat_tangent = tf.placeholder(tf.float32, [None], 'flat_tangent')
    self.flat_theta = tf.placeholder(tf.float32, [None], 'flat_theta')

  def setup(self, var_list, raw_loss, self_divergence,
            divergence=None):
    """Builds ops for the flat loss gradient, Fisher-vector products,
    and flat getting/setting of the variables in var_list."""
    self.setup_placeholders()

    self.raw_loss = raw_loss
    self.divergence = divergence
    self.loss_flat_gradient = flatgrad(raw_loss, var_list)
    self.divergence_gradient = gradients(self_divergence, var_list)

    # Unflatten the tangent placeholder into tensors shaped like var_list.
    shapes = [var.shape for var in var_list]
    sizes = [var_size(var) for var in var_list]

    start = 0
    tangents = []
    for shape, size in zip(shapes, sizes):
      param = tf.reshape(self.flat_tangent[start:start + size], shape)
      tangents.append(param)
      start += size
    assert start == sum(sizes)

    self.grad_vector_product = sum(
        tf.reduce_sum(g * t)
        for (g, t) in zip(self.divergence_gradient, tangents))
    self.fisher_vector_product = flatgrad(self.grad_vector_product, var_list)
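
    # The Fisher-vector product above is the double-backprop trick: for a
    # tangent v, F v = d/dtheta ((dD/dtheta) . v) where D is the divergence,
    # so the Fisher matrix (here, the Hessian of D) is never formed
    # explicitly.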

    self.flat_vars = get_flat(var_list)
    self.set_vars = set_from_flat(var_list, self.flat_theta)

  def optimize(self, sess, feed_dict):
    old_theta = sess.run(self.flat_vars)
    loss_flat_grad = sess.run(self.loss_flat_gradient,
                              feed_dict=feed_dict)

    def calc_fisher_vector_product(tangent):
      feed_dict[self.flat_tangent] = tangent
      fvp = sess.run(self.fisher_vector_product,
                     feed_dict=feed_dict)
      # Damping keeps the product positive definite for conjugate gradient.
      fvp += self.cg_damping * tangent
      return fvp

    # Approximately solve F * step_dir = -g via conjugate gradient.
    step_dir = conjugate_gradient(calc_fisher_vector_product, -loss_flat_grad)
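
    # Scale the step to the trust-region boundary: with s = step_dir and
    # shs = 0.5 * s^T F s, dividing s by lm = sqrt(shs / max_divergence)
    # makes the quadratic model of the divergence, 0.5 * step^T F step,
    # approximately equal to max_divergence.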
    shs = 0.5 * step_dir.dot(calc_fisher_vector_product(step_dir))
    lm = np.sqrt(shs / self.max_divergence)
    fullstep = step_dir / lm
    neggdotstepdir = -loss_flat_grad.dot(step_dir)
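
    # For a candidate step = stepfrac * fullstep, the first-order predicted
    # improvement of the loss is stepfrac * neggdotstepdir / lm, which is
    # what gets passed to linesearch below as the expected improvement rate.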

    # Returns (loss at theta, whether theta satisfies the divergence bound).
    def calc_loss(theta):
      sess.run(self.set_vars, feed_dict={self.flat_theta: theta})
      if self.divergence is None:
        return sess.run(self.raw_loss, feed_dict=feed_dict), True
      else:
        raw_loss, divergence = sess.run(
            [self.raw_loss, self.divergence], feed_dict=feed_dict)
        return raw_loss, divergence < self.max_divergence

    theta = linesearch(calc_loss, old_theta, fullstep, neggdotstepdir / lm)

    # calc_loss leaves the variables set to the last theta it evaluated, so
    # re-check the divergence at those parameters before committing.
    if self.divergence is not None:
      final_divergence = sess.run(self.divergence, feed_dict=feed_dict)
    else:
      final_divergence = None

    # Commit to the line-search result only if it stays inside the trust
    # region; otherwise fall back to the old parameters.
    if final_divergence is None or final_divergence < self.max_divergence:
      sess.run(self.set_vars, feed_dict={self.flat_theta: theta})
    else:
      sess.run(self.set_vars, feed_dict={self.flat_theta: old_theta})
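

# A minimal usage sketch (hypothetical `policy_loss`, `kl_self`, and `kl_old`
# tensors supplied by the caller; they are not defined in this module):
#
#   trust_region = TrustRegionOptimization(max_divergence=0.01)
#   trust_region.setup(tf.trainable_variables(), policy_loss, kl_self,
#                      divergence=kl_old)
#   with tf.Session() as sess:
#     sess.run(tf.global_variables_initializer())
#     trust_region.optimize(sess, feed_dict={...})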


def conjugate_gradient(f_Ax, b, cg_iters=10, residual_tol=1e-10):
  """Approximately solves A x = b given only the map f_Ax: x -> A x."""
  p = b.copy()
  r = b.copy()
  x = np.zeros_like(b)
  rdotr = r.dot(r)
  for i in xrange(cg_iters):
    z = f_Ax(p)
    v = rdotr / p.dot(z)
    x += v * p
    r -= v * z
    newrdotr = r.dot(r)
    mu = newrdotr / rdotr
    p = r + mu * p
    rdotr = newrdotr
    if rdotr < residual_tol:
      break
  return x
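

# Since A enters only through f_Ax, the Fisher matrix stays implicit here;
# a small fixed iteration count (10, as above) is the usual TRPO choice and
# typically yields a good enough natural-gradient direction.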


def linesearch(f, x, fullstep, expected_improve_rate):
  """Backtracking line search from x along fullstep.

  Accepts the first candidate that is valid (inside the trust region) and
  achieves at least accept_ratio of the linearly predicted improvement;
  returns x unchanged if no candidate qualifies.
  """
  accept_ratio = 0.1
  max_backtracks = 10

  fval, _ = f(x)
  for (_n_backtracks, stepfrac) in enumerate(.5 ** np.arange(max_backtracks)):
    xnew = x + stepfrac * fullstep
    newfval, valid = f(xnew)
    if not valid:
      continue
    actual_improve = fval - newfval
    expected_improve = expected_improve_rate * stepfrac
    ratio = actual_improve / expected_improve
    if ratio > accept_ratio and actual_improve > 0:
      return xnew

  return x