File size: 670 Bytes
5960497 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 | import numpy as np
from abc import ABC, abstractmethod
class AbstractEnvRunner(ABC):
    """Abstract base for objects that drive ``model`` against ``env``.

    The constructor records the environment and the model, sizes an
    observation buffer from ``env.observation_space`` and fills it with the
    result of ``env.reset()``.  Concrete subclasses must implement
    :meth:`run`.
    """

    def __init__(self, *, env, model, nsteps):
        """Store the collaborators and prime the observation buffer.

        Args:
            env: Object exposing ``observation_space`` (with ``shape`` and
                ``dtype``) and ``reset()``.  When it has a ``num_envs``
                attribute that value is used as the environment count;
                otherwise the count is 1.
            model: Object exposing ``initial_state``.
            nsteps: Multiplied by the environment count to form the leading
                dimension of ``batch_ob_shape``.
        """
        self.env = env
        self.model = model

        # Environments without a ``num_envs`` attribute count as one.
        if hasattr(env, 'num_envs'):
            env_count = env.num_envs
        else:
            env_count = 1
        self.nenv = env_count

        self.batch_ob_shape = (env_count * nsteps,) + env.observation_space.shape

        buffer_shape = (env_count,) + env.observation_space.shape
        buffer_dtype = env.observation_space.dtype.name
        self.obs = np.zeros(buffer_shape, dtype=buffer_dtype)
        # Slice assignment writes into the pre-allocated buffer, so the
        # buffer keeps its own shape and dtype whatever reset() returns.
        self.obs[:] = env.reset()

        self.nsteps = nsteps
        self.states = model.initial_state
        self.dones = [False for _unused in range(env_count)]

    @abstractmethod
    def run(self):
        """Subclasses must override this; the base version only raises."""
        raise NotImplementedError
|