|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| from environments.ant_maze_env import AntMazeEnv
|
| from environments.point_maze_env import PointMazeEnv
|
|
|
| import tensorflow as tf
|
| import gin.tf
|
| from tf_agents.environments import gym_wrapper
|
| from tf_agents.environments import tf_py_environment
|
|
|
|
|
| @gin.configurable
|
def create_maze_env(env_name=None, top_down_view=False):
  """Builds a maze navigation environment from its string name.

  The name is parsed as: optional 'Ego' prefix (enables egocentric
  range-sensor bins), then an agent type ('Ant' or 'Point'), then a task
  suffix ('Maze', 'Push', 'Fall', 'Block', or 'BlockMaze') — e.g.
  'AntMaze', 'EgoPointPush'.

  Args:
    env_name: Environment name string as described above. Required; the
      `None` default exists only so gin can supply the value.
    top_down_view: If True, include a top-down view in the observation.

  Returns:
    The instantiated environment wrapped in a tf_agents `GymWrapper`.

  Raises:
    ValueError: If `env_name` is None or not a recognized environment.
  """
  if env_name is None:
    raise ValueError('env_name must be provided')

  n_bins = 0
  manual_collision = False
  if env_name.startswith('Ego'):
    # Egocentric variant: 8 range-sensor bins around the agent.
    n_bins = 8
    env_name = env_name[3:]

  if env_name.startswith('Ant'):
    cls = AntMazeEnv
    env_name = env_name[3:]
    maze_size_scaling = 8
  elif env_name.startswith('Point'):
    cls = PointMazeEnv
    # The point mass has no physical body, so collisions are resolved
    # manually rather than by the physics engine.
    manual_collision = True
    env_name = env_name[5:]
    maze_size_scaling = 4
  else:
    # Raise (rather than assert) so the check survives `python -O`.
    raise ValueError('unknown env %s' % env_name)

  # The remaining suffix names the task; its maze_id is the suffix itself.
  if env_name not in ('Maze', 'Push', 'Fall', 'Block', 'BlockMaze'):
    raise ValueError('Unknown maze environment %s' % env_name)
  maze_id = env_name
  # The two Block tasks expose movable-block state and spawn a spinnable
  # block near the agent; the other tasks do not.
  observe_blocks = env_name in ('Block', 'BlockMaze')
  put_spin_near_agent = env_name in ('Block', 'BlockMaze')

  gym_mujoco_kwargs = {
      'maze_id': maze_id,
      'n_bins': n_bins,
      'observe_blocks': observe_blocks,
      'put_spin_near_agent': put_spin_near_agent,
      'top_down_view': top_down_view,
      'manual_collision': manual_collision,
      'maze_size_scaling': maze_size_scaling
  }
  gym_env = cls(**gym_mujoco_kwargs)
  # Reset once so the wrapped environment starts from a valid state.
  gym_env.reset()
  wrapped_env = gym_wrapper.GymWrapper(gym_env)
  return wrapped_env
|
|
|
|
|
class TFPyEnvironment(tf_py_environment.TFPyEnvironment):
  """tf_agents TFPyEnvironment with single-agent (unbatched) conveniences.

  The parent class works on batched time steps with a batch size of one;
  the helpers here strip that leading batch dimension so callers can deal
  in plain per-agent tensors.
  """

  def __init__(self, *args, **kwargs):
    super(TFPyEnvironment, self).__init__(*args, **kwargs)

  def start_collect(self):
    """No-op hook kept so collection-driver code can call it uniformly."""
    pass

  def current_obs(self):
    """Returns the current observation with the batch dimension removed."""
    return self.current_time_step().observation[0]

  def step(self, actions):
    """Steps the environment with a single unbatched action.

    Args:
      actions: An unbatched action tensor.

    Returns:
      A (done, reward, discount) tuple of unbatched tensors for the
      resulting time step.
    """
    batched_actions = tf.expand_dims(actions, 0)
    time_step = super(TFPyEnvironment, self).step(batched_actions)
    return (time_step.is_last()[0],
            time_step.reward[0],
            time_step.discount[0])

  def reset(self):
    """Resets the environment and returns the initial (batched) time step."""
    return super(TFPyEnvironment, self).reset()
|
|
|