| | from baselines.common.input import observation_input |
| | from baselines.common.tf_util import adjust_shape |
| |
|
| | |
| | |
| | |
| |
|
| |
|
| | class TfInput(object): |
| | def __init__(self, name="(unnamed)"): |
| | """Generalized Tensorflow placeholder. The main differences are: |
| | - possibly uses multiple placeholders internally and returns multiple values |
| | - can apply light postprocessing to the value feed to placeholder. |
| | """ |
| | self.name = name |
| |
|
| | def get(self): |
| | """Return the tf variable(s) representing the possibly postprocessed value |
| | of placeholder(s). |
| | """ |
| | raise NotImplementedError |
| |
|
| | def make_feed_dict(self, data): |
| | """Given data input it to the placeholder(s).""" |
| | raise NotImplementedError |
| |
|
| |
|
| | class PlaceholderTfInput(TfInput): |
| | def __init__(self, placeholder): |
| | """Wrapper for regular tensorflow placeholder.""" |
| | super().__init__(placeholder.name) |
| | self._placeholder = placeholder |
| |
|
| | def get(self): |
| | return self._placeholder |
| |
|
| | def make_feed_dict(self, data): |
| | return {self._placeholder: adjust_shape(self._placeholder, data)} |
| |
|
| |
|
| | class ObservationInput(PlaceholderTfInput): |
| | def __init__(self, observation_space, name=None): |
| | """Creates an input placeholder tailored to a specific observation space |
| | |
| | Parameters |
| | ---------- |
| | |
| | observation_space: |
| | observation space of the environment. Should be one of the gym.spaces types |
| | name: str |
| | tensorflow name of the underlying placeholder |
| | """ |
| | inpt, self.processed_inpt = observation_input(observation_space, name=name) |
| | super().__init__(inpt) |
| |
|
| | def get(self): |
| | return self.processed_inpt |
| |
|
| |
|
| |
|