| import numpy as np | |
| import tensorflow as tf | |
| tf.compat.v1.disable_eager_execution() | |
| from gym.spaces import Discrete, Box, MultiDiscrete | |
def observation_placeholder(ob_space, batch_size=None, name='Ob'):
    '''
    Create placeholder to feed observations into of the size appropriate to the observation space

    Parameters:
    ----------

    ob_space: gym.Space     observation space (Discrete, Box or MultiDiscrete)

    batch_size: int         size of the batch to be fed into input. Can be left None in most cases.

    name: str               name of the placeholder

    Returns:
    -------

    tensorflow placeholder tensor of shape (batch_size,) + ob_space.shape

    Raises:
    ------

    AssertionError          if ob_space is not a Discrete, Box or MultiDiscrete space
    '''
    # Kept as an assert so the exception type seen by callers does not change;
    # the message lists every space type that is actually accepted.
    assert isinstance(ob_space, (Discrete, Box, MultiDiscrete)), \
        'Can only deal with Discrete, Box and MultiDiscrete observation spaces for now'

    dtype = ob_space.dtype
    # Spaces reporting int8 are fed through an unsigned 8-bit placeholder instead.
    # NOTE(review): presumably meant for byte-valued (image-like) observations - confirm.
    if dtype == np.int8:
        dtype = np.uint8

    # The leading axis is the batch axis; None leaves the batch size unspecified.
    return tf.compat.v1.placeholder(shape=(batch_size,) + ob_space.shape, dtype=dtype, name=name)
def observation_input(ob_space, batch_size=None, name='Ob'):
    '''
    Build the observation placeholder for the given observation space together
    with its encoded form.

    Parameters:
    ----------

    ob_space: gym.Space     observation space

    batch_size: int         size of the batch to be fed into input. Can be left None in most cases.

    name: str               name of the placeholder

    Returns:
    -------

    tuple (placeholder, encoded) where placeholder is the result of
    observation_placeholder and encoded is encode_observation applied to it
    '''
    ob_ph = observation_placeholder(ob_space, batch_size, name)
    encoded_ob = encode_observation(ob_space, ob_ph)
    return ob_ph, encoded_ob
def encode_observation(ob_space, placeholder):
    '''
    Encode input in the way that is appropriate to the observation space

    Parameters:
    ----------

    ob_space: gym.Space             observation space

    placeholder: tf.placeholder     observation input placeholder

    Returns:
    -------

    float32 tensor:
        Discrete        one-hot encoding of the input with depth ob_space.n
        Box             the input cast to float32
        MultiDiscrete   one-hot encodings of each component (depths taken from
                        ob_space.nvec) concatenated along the last axis

    Raises:
    ------

    NotImplementedError             if ob_space is not one of the space types above
    '''
    if isinstance(ob_space, Discrete):
        return tf.cast(tf.one_hot(placeholder, ob_space.n), dtype=tf.float32)

    if isinstance(ob_space, Box):
        return tf.cast(placeholder, dtype=tf.float32)

    if isinstance(ob_space, MultiDiscrete):
        indices = tf.cast(placeholder, tf.int32)
        # Take the component count and depths from the space itself rather than
        # from the tensor's static shape, so a statically unknown last dimension
        # does not break graph construction.
        one_hots = [
            tf.cast(tf.one_hot(indices[..., i], depth), dtype=tf.float32)
            for i, depth in enumerate(ob_space.nvec)
        ]
        return tf.concat(one_hots, axis=-1)

    raise NotImplementedError(
        'Cannot encode observation space of type {}'.format(type(ob_space).__name__))