File size: 5,069 Bytes
5960497 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 | import multiprocessing as mp
import numpy as np
from .vec_env import VecEnv, CloudpickleWrapper, clear_mpi_env_vars
def worker(remote, parent_remote, env_fn_wrappers):
    """
    Entry point of a subprocess that owns one or more environments.

    Arguments:

    remote: this process's end of the pipe; commands are read from it and replies are written to it
    parent_remote: the other end of the pipe; closed immediately because this process does not use it
    env_fn_wrappers: object whose ``x`` attribute is an iterable of zero-argument callables, each building one environment

    Each message is a ``(cmd, data)`` pair. Supported commands are 'step',
    'reset', 'render', 'get_spaces_spec' and 'close'; anything else raises
    NotImplementedError. The environments are closed when the loop ends,
    whether it ends normally or through an exception.
    """
    def advance(env, action):
        # A finished episode is restarted right away, so the caller
        # receives the first observation of the next episode.
        obs, rew, finished, extra = env.step(action)
        if finished:
            obs = env.reset()
        return obs, rew, finished, extra

    parent_remote.close()
    envs = [make_env() for make_env in env_fn_wrappers.x]
    try:
        while True:
            cmd, data = remote.recv()
            if cmd == 'close':
                remote.close()
                break

            if cmd == 'step':
                reply = [advance(env, action) for env, action in zip(envs, data)]
            elif cmd == 'reset':
                reply = [env.reset() for env in envs]
            elif cmd == 'render':
                reply = [env.render(mode='rgb_array') for env in envs]
            elif cmd == 'get_spaces_spec':
                first = envs[0]
                reply = CloudpickleWrapper((first.observation_space, first.action_space, first.spec))
            else:
                raise NotImplementedError
            remote.send(reply)
    except KeyboardInterrupt:
        print('SubprocVecEnv worker: got KeyboardInterrupt')
    finally:
        for env in envs:
            env.close()
class SubprocVecEnv(VecEnv):
    """
    VecEnv that runs multiple environments in parallel in subprocesses and communicates with them via pipes.
    Recommended to use when num_envs > 1 and step() can be a bottleneck.
    """

    def __init__(self, env_fns, spaces=None, context='spawn', in_series=1):
        """
        Arguments:

        env_fns: sequence of callables - functions that create environments to run in subprocesses. Need to be cloud-pickleable
        spaces: never read by this constructor; the spaces are always queried from the first worker
        context: name of the multiprocessing start method, passed to mp.get_context
        in_series: number of environments to run in series in a single process
        (e.g. when len(env_fns) == 12 and in_series == 3, it will run 4 processes, each running 3 envs in series)
        """
        # Give every attribute read by the teardown path (__del__ -> close()
        # -> close_extras) a safe value before anything below can raise.
        # Otherwise a failed construction (e.g. the divisibility assert)
        # leaves a half-built object whose __del__ triggers a secondary
        # AttributeError during garbage collection.
        self.waiting = False
        self.closed = False
        self.viewer = None
        self.remotes = ()
        self.work_remotes = ()
        self.ps = []

        self.in_series = in_series
        nenvs = len(env_fns)
        assert nenvs % in_series == 0, "Number of envs must be divisible by number of envs to run in series"
        self.nremotes = nenvs // in_series
        # One chunk of env constructors per worker process.
        env_fns = np.array_split(env_fns, self.nremotes)
        ctx = mp.get_context(context)
        self.remotes, self.work_remotes = zip(*[ctx.Pipe() for _ in range(self.nremotes)])
        self.ps = [ctx.Process(target=worker, args=(work_remote, remote, CloudpickleWrapper(env_fn)))
                   for (work_remote, remote, env_fn) in zip(self.work_remotes, self.remotes, env_fns)]
        for p in self.ps:
            p.daemon = True  # if the main process crashes, we should not cause things to hang
            # NOTE(review): presumably hides MPI-related environment variables
            # from the child while it starts - confirm against vec_env.
            with clear_mpi_env_vars():
                p.start()
        # The parent keeps only its own pipe ends; the worker ends now live
        # in the child processes.
        for remote in self.work_remotes:
            remote.close()

        # All workers are assumed to host identical environments, so the
        # spaces and spec are queried from the first one only.
        self.remotes[0].send(('get_spaces_spec', None))
        observation_space, action_space, self.spec = self.remotes[0].recv().x
        VecEnv.__init__(self, nenvs, observation_space, action_space)

    def step_async(self, actions):
        """Split `actions` into one chunk per worker and send each worker a 'step' command."""
        self._assert_not_closed()
        actions = np.array_split(actions, self.nremotes)
        for remote, action in zip(self.remotes, actions):
            remote.send(('step', action))
        self.waiting = True

    def step_wait(self):
        """
        Collect the replies to the pending 'step' command.

        Returns (obs, rews, dones, infos): observations merged by _flatten_obs,
        rewards and dones stacked with np.stack, infos as a tuple.
        """
        self._assert_not_closed()
        results = [remote.recv() for remote in self.remotes]
        results = _flatten_list(results)
        self.waiting = False
        obs, rews, dones, infos = zip(*results)
        return _flatten_obs(obs), np.stack(rews), np.stack(dones), infos

    def reset(self):
        """Reset every environment and return the merged initial observations."""
        self._assert_not_closed()
        for remote in self.remotes:
            remote.send(('reset', None))
        obs = [remote.recv() for remote in self.remotes]
        obs = _flatten_list(obs)
        return _flatten_obs(obs)

    def close_extras(self):
        """Shut the workers down: drain pending replies, send 'close', join the processes."""
        self.closed = True
        if self.waiting:
            # Consume the outstanding step results before sending 'close'.
            for remote in self.remotes:
                remote.recv()
        for remote in self.remotes:
            remote.send(('close', None))
        for p in self.ps:
            p.join()

    def get_images(self):
        """Ask every worker to render and return the frames as one flat list."""
        self._assert_not_closed()
        for pipe in self.remotes:
            pipe.send(('render', None))
        imgs = [pipe.recv() for pipe in self.remotes]
        imgs = _flatten_list(imgs)
        return imgs

    def _assert_not_closed(self):
        assert not self.closed, "Trying to operate on a SubprocVecEnv after calling close()"

    def __del__(self):
        # Best-effort cleanup when the object is collected without an
        # explicit close().
        if not self.closed:
            self.close()
def _flatten_obs(obs):
    """
    Stack a non-empty list/tuple of per-environment observations.

    When the first observation is a dict, the result is a dict with the same
    keys, each value being the np.stack of that key across all observations.
    Otherwise the observations are stacked directly with np.stack.
    """
    assert isinstance(obs, (list, tuple))
    assert len(obs) > 0

    first = obs[0]
    if not isinstance(first, dict):
        return np.stack(obs)
    return {key: np.stack([ob[key] for ob in obs]) for key in first.keys()}
def _flatten_list(l):
    """
    Concatenate a non-empty list/tuple of non-empty sequences into one flat list.

    Only one level of nesting is removed; the order of the items is preserved.
    """
    assert isinstance(l, (list, tuple))
    assert len(l) > 0
    sizes = [len(chunk) for chunk in l]
    assert all(size > 0 for size in sizes)

    flat = []
    for chunk in l:
        flat.extend(chunk)
    return flat
|