| | import numpy as np |
| | from abc import ABC, abstractmethod |
| |
|
class AbstractEnvRunner(ABC):
    """Base class for objects that step a model through an environment.

    The constructor resets the environment once and caches the bookkeeping
    that a concrete ``run`` implementation needs. Subclasses must override
    ``run``; this class cannot be instantiated directly.

    Attributes set by ``__init__``:
        env: the environment passed in.
        model: the model passed in.
        nenv: ``env.num_envs`` when that attribute exists, otherwise 1.
        batch_ob_shape: ``(nenv * nsteps,) + env.observation_space.shape``.
        obs: array of shape ``(nenv,) + env.observation_space.shape`` holding
            the result of the initial ``env.reset()``.
        nsteps: the ``nsteps`` argument, stored unchanged.
        states: ``model.initial_state`` as read at construction time.
        dones: list of ``nenv`` ``False`` values.
    """

    def __init__(self, *, env, model, nsteps):
        """Store the collaborators and reset ``env`` to get the first observation.

        All arguments are keyword-only.

        Args:
            env: object exposing ``observation_space`` (with ``shape`` and
                ``dtype``) and ``reset()``; ``num_envs`` is optional.
            model: object exposing ``initial_state``.
            nsteps: multiplier used for ``batch_ob_shape``; presumably the
                number of steps collected per ``run`` call -- confirm against
                the concrete subclasses.
        """
        self.env = env
        self.model = model

        # Environments without a ``num_envs`` attribute count as a single env.
        self.nenv = nenv = getattr(env, 'num_envs', 1)

        ob_space = env.observation_space
        self.batch_ob_shape = (nenv * nsteps,) + ob_space.shape

        # Allocate the buffer with the space's dtype first, then copy the
        # reset observation in, so the stored array keeps that dtype and the
        # (nenv, ...) shape regardless of what ``reset`` returns.
        self.obs = np.zeros((nenv,) + ob_space.shape, dtype=ob_space.dtype.name)
        self.obs[:] = env.reset()

        self.nsteps = nsteps
        self.states = model.initial_state
        self.dones = [False] * nenv

    @abstractmethod
    def run(self):
        """Abstract hook; concrete runners must override this method."""
        raise NotImplementedError
| |
|
| |
|