"""Build TensorFlow (v1 graph-mode) observation inputs from gym spaces."""

import numpy as np
import tensorflow as tf

# The helpers below create `tf.compat.v1.placeholder` tensors, so eager
# execution is switched off at import time.
tf.compat.v1.disable_eager_execution()

from gym.spaces import Discrete, Box, MultiDiscrete


def observation_placeholder(ob_space, batch_size=None, name='Ob'):
    """
    Create a placeholder sized for observations drawn from `ob_space`.

    Parameters
    ----------
    ob_space : gym.Space
        Observation space. Must be a Discrete, Box or MultiDiscrete space.
    batch_size : int or None
        Size of the leading (batch) dimension of the placeholder.
        Can be left None in most cases.
    name : str
        Name of the placeholder.

    Returns
    -------
    tensorflow placeholder tensor
        Shape is ``(batch_size,) + ob_space.shape``; dtype follows
        ``ob_space.dtype`` (with int8 replaced by uint8).

    Raises
    ------
    AssertionError
        If `ob_space` is not one of the supported space types.
    """
    assert isinstance(ob_space, (Discrete, Box, MultiDiscrete)), \
        'Can only deal with Discrete, Box and MultiDiscrete observation spaces for now'

    dtype = ob_space.dtype
    # int8 spaces are fed through a uint8 placeholder.
    # NOTE(review): presumably to match how image-like observations are
    # stored by callers -- confirm before changing.
    if dtype == np.int8:
        dtype = np.uint8

    return tf.compat.v1.placeholder(
        shape=(batch_size,) + ob_space.shape,
        dtype=dtype,
        name=name,
    )


def observation_input(ob_space, batch_size=None, name='Ob'):
    """
    Create an observation placeholder together with its encoded form.

    Parameters
    ----------
    ob_space : gym.Space
        Observation space.
    batch_size : int or None
        Size of the leading (batch) dimension of the placeholder.
    name : str
        Name of the placeholder.

    Returns
    -------
    tuple
        ``(placeholder, encoded)`` where `placeholder` comes from
        `observation_placeholder` and `encoded` is the result of passing it
        through `encode_observation`.
    """
    placeholder = observation_placeholder(ob_space, batch_size, name)
    return placeholder, encode_observation(ob_space, placeholder)


def encode_observation(ob_space, placeholder):
    """
    Encode an observation tensor in the way appropriate to `ob_space`.

    Parameters
    ----------
    ob_space : gym.Space
        Observation space.
    placeholder : tf.placeholder
        Observation input placeholder.

    Returns
    -------
    tensorflow tensor of dtype float32
        * Discrete: one-hot encoding with depth ``ob_space.n``.
        * Box: the input cast to float32.
        * MultiDiscrete: each component along the last axis is one-hot
          encoded with depth ``ob_space.nvec[i]`` and the results are
          concatenated along the last axis.

    Raises
    ------
    NotImplementedError
        If `ob_space` is of any other type.
    """
    if isinstance(ob_space, Discrete):
        # tf.one_hot emits float32 directly, so no extra cast is needed.
        return tf.one_hot(placeholder, ob_space.n, dtype=tf.float32)

    if isinstance(ob_space, Box):
        return tf.cast(placeholder, dtype=tf.float32)

    if isinstance(ob_space, MultiDiscrete):
        indices = tf.cast(placeholder, tf.int32)
        # One one-hot block per sub-action; the number of sub-actions is
        # read from the static size of the last axis.
        one_hots = [
            tf.one_hot(indices[..., i], ob_space.nvec[i], dtype=tf.float32)
            for i in range(indices.shape[-1])
        ]
        return tf.concat(one_hots, axis=-1)

    raise NotImplementedError