| import tensorflow as tf |
| import numpy as np |
| import baselines.common.tf_util as U |
| from baselines.a2c.utils import fc |
| from tensorflow.python.ops import math_ops |
|
|
class Pd(object):
    """
    A particular probability distribution

    Abstract interface: subclasses supply the parametrization
    (flatparam), point statistics (mode), densities (neglogp),
    divergences (kl, entropy) and sampling.
    """

    def flatparam(self):
        """Return the flat tensor of distribution parameters."""
        raise NotImplementedError

    def mode(self):
        """Return the most likely value under the distribution."""
        raise NotImplementedError

    def neglogp(self, x):
        """Return the negative log-probability of `x`."""
        raise NotImplementedError

    def kl(self, other):
        """Return KL(self || other)."""
        raise NotImplementedError

    def entropy(self):
        """Return the entropy of the distribution."""
        raise NotImplementedError

    def sample(self):
        """Draw a sample from the distribution."""
        raise NotImplementedError

    def logp(self, x):
        # Log-probability is simply the negated neglogp.
        return -self.neglogp(x)

    def get_shape(self):
        return self.flatparam().shape

    @property
    def shape(self):
        return self.get_shape()

    def __getitem__(self, idx):
        # Index into the flat parameters and wrap the slice in the same
        # distribution class.
        return type(self)(self.flatparam()[idx])
|
|
class PdType(object):
    """
    Parametrized family of probability distributions

    A PdType knows how to build a concrete Pd from a flat parameter
    tensor or from a latent feature vector, and describes the shapes
    and dtypes involved.
    """

    def pdclass(self):
        """Return the Pd subclass this family instantiates."""
        raise NotImplementedError

    def pdfromflat(self, flat):
        """Build a distribution instance from a flat parameter tensor."""
        cls = self.pdclass()
        return cls(flat)

    def pdfromlatent(self, latent_vector, init_scale, init_bias):
        """Build (pd, flat_params) from a latent feature vector."""
        raise NotImplementedError

    def param_shape(self):
        """Shape of the flat parameter vector (per sample)."""
        raise NotImplementedError

    def sample_shape(self):
        """Shape of one sample drawn from the distribution."""
        raise NotImplementedError

    def sample_dtype(self):
        """Dtype of samples drawn from the distribution."""
        raise NotImplementedError

    def param_placeholder(self, prepend_shape, name=None):
        full_shape = prepend_shape + self.param_shape()
        return tf.compat.v1.placeholder(dtype=tf.float32, shape=full_shape, name=name)

    def sample_placeholder(self, prepend_shape, name=None):
        full_shape = prepend_shape + self.sample_shape()
        return tf.compat.v1.placeholder(dtype=self.sample_dtype(), shape=full_shape, name=name)

    def __eq__(self, other):
        # Two pdtypes are interchangeable when they are the same class
        # configured identically.
        if type(self) != type(other):
            return False
        return self.__dict__ == other.__dict__
|
|
class CategoricalPdType(PdType):
    """PdType for a single discrete action with `ncat` choices."""

    def __init__(self, ncat):
        # Number of categories (length of the logit vector).
        self.ncat = ncat

    def pdclass(self):
        return CategoricalPd

    def pdfromlatent(self, latent_vector, init_scale=1.0, init_bias=0.0):
        # Project latent features to one logit per category (or pass
        # through when the latent already has the right width).
        logits = _matching_fc(latent_vector, 'pi', self.ncat, init_scale=init_scale, init_bias=init_bias)
        return self.pdfromflat(logits), logits

    def param_shape(self):
        return [self.ncat]

    def sample_shape(self):
        # A sample is a scalar category index.
        return []

    def sample_dtype(self):
        return tf.int32
|
|
|
|
class MultiCategoricalPdType(PdType):
    """PdType for a MultiDiscrete action space: one categorical per dimension.

    `nvec` holds the number of choices for each sub-action, e.g. [2, 3]
    means the first sub-action has 2 options and the second has 3.
    """
    def __init__(self, nvec):
        # np.asarray accepts plain Python lists/tuples as well as ndarrays;
        # the previous `nvec.astype('int32')` crashed on a list (as passed
        # by test_probtypes below).
        self.ncats = np.asarray(nvec, dtype='int32')
        assert (self.ncats > 0).all(), 'every sub-action must have at least one choice'
    def pdclass(self):
        return MultiCategoricalPd
    def pdfromflat(self, flat):
        # MultiCategoricalPd needs the per-dimension sizes to split `flat`.
        return MultiCategoricalPd(self.ncats, flat)

    def pdfromlatent(self, latent, init_scale=1.0, init_bias=0.0):
        # One shared fc layer emits the concatenated logits of all
        # sub-distributions.
        pdparam = _matching_fc(latent, 'pi', self.ncats.sum(), init_scale=init_scale, init_bias=init_bias)
        return self.pdfromflat(pdparam), pdparam

    def param_shape(self):
        # Concatenated logits of all sub-distributions.
        return [sum(self.ncats)]
    def sample_shape(self):
        # One integer index per sub-action.
        return [len(self.ncats)]
    def sample_dtype(self):
        return tf.int32
|
|
class DiagGaussianPdType(PdType):
    """PdType for a continuous action vector modeled as a diagonal Gaussian."""

    def __init__(self, size):
        # Dimensionality of the action vector.
        self.size = size

    def pdclass(self):
        return DiagGaussianPd

    def pdfromlatent(self, latent_vector, init_scale=1.0, init_bias=0.0):
        # The mean comes from the latent features; logstd is a free
        # variable shared across the batch (state-independent).
        mean = _matching_fc(latent_vector, 'pi', self.size, init_scale=init_scale, init_bias=init_bias)
        logstd = tf.compat.v1.get_variable(
            name='pi/logstd',
            shape=[1, self.size],
            initializer=tf.compat.v1.zeros_initializer())
        # `mean * 0.0 + logstd` broadcasts logstd to the batch dimension.
        pdparam = tf.concat([mean, mean * 0.0 + logstd], axis=1)
        return self.pdfromflat(pdparam), mean

    def param_shape(self):
        # Flat params are [mean, logstd] concatenated.
        return [2 * self.size]

    def sample_shape(self):
        return [self.size]

    def sample_dtype(self):
        return tf.float32
|
|
class BernoulliPdType(PdType):
    """PdType for a MultiBinary action space: `size` independent coin flips."""

    def __init__(self, size):
        # Number of independent binary decisions.
        self.size = size

    def pdclass(self):
        return BernoulliPd

    def param_shape(self):
        # One logit per binary decision.
        return [self.size]

    def sample_shape(self):
        return [self.size]

    def sample_dtype(self):
        return tf.int32

    def pdfromlatent(self, latent_vector, init_scale=1.0, init_bias=0.0):
        logits = _matching_fc(latent_vector, 'pi', self.size, init_scale=init_scale, init_bias=init_bias)
        return self.pdfromflat(logits), logits
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
class CategoricalPd(Pd):
    """Categorical distribution over discrete classes, parametrized by raw logits."""

    def __init__(self, logits):
        self.logits = logits

    def flatparam(self):
        return self.logits

    def mode(self):
        # Most likely class = argmax over logits.
        return tf.argmax(input=self.logits, axis=-1)

    @property
    def mean(self):
        return tf.nn.softmax(self.logits)

    def neglogp(self, x):
        # Integer labels are converted to one-hot vectors; anything else is
        # assumed to already be a (possibly soft) label distribution.
        if x.dtype in {tf.uint8, tf.int32, tf.int64}:
            label_dims = x.shape.as_list()
            logit_dims = self.logits.get_shape().as_list()[:-1]
            # Compare only the statically known dimensions.
            for xd, ld in zip(label_dims, logit_dims):
                if xd is not None and ld is not None:
                    assert xd == ld, 'shape mismatch: {} in x vs {} in logits'.format(xd, ld)
            x = tf.one_hot(x, self.logits.get_shape().as_list()[-1])
        else:
            assert x.shape.as_list() == self.logits.shape.as_list()
        return tf.nn.softmax_cross_entropy_with_logits(
            logits=self.logits,
            labels=x)

    def kl(self, other):
        # Numerically stable KL: shift both logit sets by their per-row max
        # before exponentiating.
        shifted_p = self.logits - tf.reduce_max(input_tensor=self.logits, axis=-1, keepdims=True)
        shifted_q = other.logits - tf.reduce_max(input_tensor=other.logits, axis=-1, keepdims=True)
        exp_p = tf.exp(shifted_p)
        exp_q = tf.exp(shifted_q)
        z_p = tf.reduce_sum(input_tensor=exp_p, axis=-1, keepdims=True)
        z_q = tf.reduce_sum(input_tensor=exp_q, axis=-1, keepdims=True)
        probs_p = exp_p / z_p
        return tf.reduce_sum(
            input_tensor=probs_p * (shifted_p - tf.math.log(z_p) - shifted_q + tf.math.log(z_q)),
            axis=-1)

    def entropy(self):
        # Same max-shift trick as in kl() for numerical stability.
        shifted = self.logits - tf.reduce_max(input_tensor=self.logits, axis=-1, keepdims=True)
        exp_shifted = tf.exp(shifted)
        z = tf.reduce_sum(input_tensor=exp_shifted, axis=-1, keepdims=True)
        probs = exp_shifted / z
        return tf.reduce_sum(input_tensor=probs * (tf.math.log(z) - shifted), axis=-1)

    def sample(self):
        # Gumbel-max trick: argmax(logits + Gumbel noise) samples from the
        # categorical without an explicit softmax.
        u = tf.random.uniform(tf.shape(input=self.logits), dtype=self.logits.dtype)
        return tf.argmax(input=self.logits - tf.math.log(-tf.math.log(u)), axis=-1)

    @classmethod
    def fromflat(cls, flat):
        return cls(flat)
|
|
class MultiCategoricalPd(Pd):
    """Product of independent categoricals (one per MultiDiscrete dimension)."""

    def __init__(self, nvec, flat):
        self.flat = flat
        # Split the concatenated logits into one slice per sub-distribution.
        splits = tf.split(flat, np.array(nvec, dtype=np.int32), axis=-1)
        self.categoricals = [CategoricalPd(logits) for logits in splits]

    def flatparam(self):
        return self.flat

    def mode(self):
        modes = [cat.mode() for cat in self.categoricals]
        return tf.cast(tf.stack(modes, axis=-1), tf.int32)

    def neglogp(self, x):
        # Joint neglogp is the sum over independent dimensions.
        per_dim = [cat.neglogp(xi) for cat, xi in zip(self.categoricals, tf.unstack(x, axis=-1))]
        return tf.add_n(per_dim)

    def kl(self, other):
        # KL of a product distribution is the sum of per-dimension KLs.
        return tf.add_n([cat.kl(o) for cat, o in zip(self.categoricals, other.categoricals)])

    def entropy(self):
        return tf.add_n([cat.entropy() for cat in self.categoricals])

    def sample(self):
        draws = [cat.sample() for cat in self.categoricals]
        return tf.cast(tf.stack(draws, axis=-1), tf.int32)

    @classmethod
    def fromflat(cls, flat):
        # Cannot reconstruct without nvec; use MultiCategoricalPdType instead.
        raise NotImplementedError
|
|
class DiagGaussianPd(Pd):
    """Diagonal-covariance Gaussian distribution.

    `flat` is the concatenation [mean, logstd] along the last axis.
    """
    def __init__(self, flat):
        self.flat = flat
        # axis=-1 is equivalent to len(flat.shape)-1 for known ranks, but
        # also works when the tensor's static rank is unknown.
        mean, logstd = tf.split(axis=-1, num_or_size_splits=2, value=flat)
        self.mean = mean
        self.logstd = logstd
        self.std = tf.exp(logstd)
    def flatparam(self):
        return self.flat
    def mode(self):
        # The mode of a Gaussian is its mean.
        return self.mean
    def neglogp(self, x):
        # -log p(x) = 0.5*sum(((x-mu)/sigma)^2) + 0.5*k*log(2*pi) + sum(log sigma)
        return 0.5 * tf.reduce_sum(input_tensor=tf.square((x - self.mean) / self.std), axis=-1) \
               + 0.5 * np.log(2.0 * np.pi) * tf.cast(tf.shape(input=x)[-1], dtype=tf.float32) \
               + tf.reduce_sum(input_tensor=self.logstd, axis=-1)
    def kl(self, other):
        assert isinstance(other, DiagGaussianPd)
        # Closed-form KL divergence between two diagonal Gaussians.
        return tf.reduce_sum(
            input_tensor=other.logstd - self.logstd
            + (tf.square(self.std) + tf.square(self.mean - other.mean)) / (2.0 * tf.square(other.std))
            - 0.5,
            axis=-1)
    def entropy(self):
        # H = sum(log sigma) + 0.5*k*log(2*pi*e)
        return tf.reduce_sum(input_tensor=self.logstd + .5 * np.log(2.0 * np.pi * np.e), axis=-1)
    def sample(self):
        # Reparametrized sample: mean + sigma * eps, eps ~ N(0, I).
        return self.mean + self.std * tf.random.normal(tf.shape(input=self.mean))
    @classmethod
    def fromflat(cls, flat):
        return cls(flat)
|
|
|
|
class BernoulliPd(Pd):
    """Vector of independent Bernoulli variables, parametrized by logits."""

    def __init__(self, logits):
        self.logits = logits
        # Per-component success probabilities.
        self.ps = tf.sigmoid(logits)

    def flatparam(self):
        return self.logits

    @property
    def mean(self):
        return self.ps

    def mode(self):
        # Round each probability to the nearer of 0/1.
        return tf.round(self.ps)

    def neglogp(self, x):
        ce = tf.nn.sigmoid_cross_entropy_with_logits(
            logits=self.logits, labels=tf.cast(x, dtype=tf.float32))
        return tf.reduce_sum(input_tensor=ce, axis=-1)

    def kl(self, other):
        # KL(p||q) = H(p, q) - H(p), expressed via sigmoid cross-entropies.
        ce_other = tf.nn.sigmoid_cross_entropy_with_logits(logits=other.logits, labels=self.ps)
        ce_self = tf.nn.sigmoid_cross_entropy_with_logits(logits=self.logits, labels=self.ps)
        return tf.reduce_sum(input_tensor=ce_other, axis=-1) - tf.reduce_sum(input_tensor=ce_self, axis=-1)

    def entropy(self):
        # Entropy equals the cross-entropy of the distribution with itself.
        ce_self = tf.nn.sigmoid_cross_entropy_with_logits(logits=self.logits, labels=self.ps)
        return tf.reduce_sum(input_tensor=ce_self, axis=-1)

    def sample(self):
        # Draw uniform noise and threshold against the probabilities.
        u = tf.random.uniform(tf.shape(input=self.ps))
        return tf.cast(math_ops.less(u, self.ps), dtype=tf.float32)

    @classmethod
    def fromflat(cls, flat):
        return cls(flat)
|
|
def make_pdtype(ac_space):
    """Map a gym action space to the matching PdType.

    Box -> DiagGaussianPdType, Discrete -> CategoricalPdType,
    MultiDiscrete -> MultiCategoricalPdType, MultiBinary -> BernoulliPdType.

    Raises:
        NotImplementedError: for unsupported action-space types (with the
            offending type named, instead of the previous bare raise).
    """
    from gym import spaces
    if isinstance(ac_space, spaces.Box):
        assert len(ac_space.shape) == 1, 'only 1-D Box action spaces are supported'
        return DiagGaussianPdType(ac_space.shape[0])
    elif isinstance(ac_space, spaces.Discrete):
        return CategoricalPdType(ac_space.n)
    elif isinstance(ac_space, spaces.MultiDiscrete):
        return MultiCategoricalPdType(ac_space.nvec)
    elif isinstance(ac_space, spaces.MultiBinary):
        return BernoulliPdType(ac_space.n)
    else:
        raise NotImplementedError('no pdtype for action space {}'.format(type(ac_space)))
|
|
def shape_el(v, i):
    """Return element `i` of v's shape: the static value when known,
    otherwise a dynamic shape tensor."""
    static = v.get_shape()[i]
    if static is None:
        # Fall back to the runtime shape for unknown dimensions.
        return tf.shape(input=v)[i]
    return static
|
|
@U.in_session
def test_probtypes():
    """Run validate_probtype against each distribution family."""
    np.random.seed(0)

    # Diagonal Gaussian: flat params are [mean, logstd] concatenated.
    gauss_params = np.array([-.2, .3, .4, -.5, .1, -.5, .1, 0.8])
    validate_probtype(DiagGaussianPdType(gauss_params.size // 2), gauss_params)

    # Single categorical: one logit per class.
    cat_params = np.array([-.2, .3, .5])
    validate_probtype(CategoricalPdType(cat_params.size), cat_params)

    # Multi-categorical: sub-distributions of sizes 1, 2 and 3.
    multicat_params = np.array([-.2, .3, .5, .1, 1, -.1])
    validate_probtype(MultiCategoricalPdType([1, 2, 3]), multicat_params)

    # Bernoulli: one logit per independent binary component.
    bern_params = np.array([-.2, .3, .5])
    validate_probtype(BernoulliPdType(bern_params.size), bern_params)
|
|
|
|
def validate_probtype(probtype, pdparam):
    """Statistically check a PdType's entropy and KL against sample estimates.

    Draws N samples from the distribution parametrized by `pdparam`, then
    verifies (within 3 standard errors) that:
      * the analytic entropy matches the Monte-Carlo estimate -mean(logp(x)),
      * the analytic KL to a randomly perturbed distribution matches its
        sample-based estimate.

    Raises AssertionError when either check falls outside the tolerance.
    Requires an active TF session (see the @U.in_session caller above).
    """
    N = 100000

    # Replicate the single parameter vector across a batch of N rows.
    Mval = np.repeat(pdparam[None, :], N, axis=0)
    M = probtype.param_placeholder([N])
    X = probtype.sample_placeholder([N])
    pd = probtype.pdfromflat(M)
    calcloglik = U.function([X, M], pd.logp(X))
    calcent = U.function([M], pd.entropy())
    Xval = tf.compat.v1.get_default_session().run(pd.sample(), feed_dict={M:Mval})
    logliks = calcloglik(Xval, Mval)
    # Monte-Carlo entropy estimate: E[-log p(x)] over the drawn samples.
    entval_ll = - logliks.mean()
    entval_ll_stderr = logliks.std() / np.sqrt(N)
    entval = calcent(Mval).mean()
    # Analytic entropy must agree with the estimate within 3 standard errors.
    assert np.abs(entval - entval_ll) < 3 * entval_ll_stderr

    # Second check: KL(pd || pd2) for a perturbed copy pd2 of the
    # distribution, compared against KL = -H(p) - E_p[log q(x)].
    M2 = probtype.param_placeholder([N])
    pd2 = probtype.pdfromflat(M2)
    q = pdparam + np.random.randn(pdparam.size) * 0.1
    Mval2 = np.repeat(q[None, :], N, axis=0)
    calckl = U.function([M, M2], pd.kl(pd2))
    klval = calckl(Mval, Mval2).mean()
    logliks = calcloglik(Xval, Mval2)
    klval_ll = - entval - logliks.mean()
    klval_ll_stderr = logliks.std() / np.sqrt(N)
    assert np.abs(klval - klval_ll) < 3 * klval_ll_stderr
    print('ok on', probtype, pdparam)
|
|
|
|
| def _matching_fc(tensor, name, size, init_scale, init_bias): |
| if tensor.shape[-1] == size: |
| return tensor |
| else: |
| return fc(tensor, name, size, init_scale=init_scale, init_bias=init_bias) |
|
|