| from . import VecEnvWrapper | |
| from baselines.bench.monitor import ResultsWriter | |
| import numpy as np | |
| import time | |
| from collections import deque | |
class VecMonitor(VecEnvWrapper):
    """Vectorized-env wrapper that tracks per-environment episode statistics.

    For every sub-environment it accumulates the reward and the number of
    steps of the running episode.  When a sub-environment reports ``done``,
    a copy of its info dict is returned with an extra ``'episode'`` entry
    holding ``'r'`` (accumulated reward), ``'l'`` (step count) and ``'t'``
    (seconds since this wrapper was constructed), plus any requested
    ``info_keywords``.
    """

    def __init__(self, venv, filename=None, keep_buf=0, info_keywords=()):
        """Wrap ``venv`` and set up optional logging and history buffers.

        Args:
            venv: the vectorized environment to wrap.
            filename: if truthy, a ``ResultsWriter`` is created with it and
                receives one row per finished episode.
            keep_buf: if non-zero, the most recent ``keep_buf`` episode
                returns/lengths are kept in ``epret_buf`` / ``eplen_buf``.
            info_keywords: keys copied from the finishing step's info dict
                into the episode record.
        """
        VecEnvWrapper.__init__(self, venv)

        # The per-env accumulators only exist after reset() has been called.
        self.eprets = None
        self.eplens = None
        self.epcount = 0
        self.tstart = time.time()

        self.results_writer = (
            ResultsWriter(filename, header={'t_start': self.tstart},
                          extra_keys=info_keywords)
            if filename
            else None
        )

        self.info_keywords = info_keywords
        self.keep_buf = keep_buf
        if keep_buf:
            # Bounded deques: the oldest entry is dropped once full.
            self.epret_buf = deque(maxlen=keep_buf)
            self.eplen_buf = deque(maxlen=keep_buf)

    def reset(self):
        """Reset the wrapped env and zero all per-env accumulators.

        Returns:
            Whatever ``self.venv.reset()`` returns.
        """
        first_obs = self.venv.reset()
        env_count = self.num_envs
        self.eprets = np.zeros(env_count, 'f')
        self.eplens = np.zeros(env_count, 'i')
        return first_obs

    def step_wait(self):
        """Collect a step from the wrapped env and record finished episodes.

        Returns:
            ``(obs, rews, dones, newinfos)`` where ``newinfos`` is a list
            copy of the wrapped env's infos; entries of finished
            sub-environments are replaced by a copy carrying an
            ``'episode'`` dict.
        """
        obs, rews, dones, infos = self.venv.step_wait()
        self.eprets += rews
        self.eplens += 1

        newinfos = list(infos[:])
        for env_idx, finished in enumerate(dones):
            if not finished:
                continue

            # Work on a copy so the wrapped env's own info dict is untouched.
            done_info = infos[env_idx].copy()
            ep_return = self.eprets[env_idx]
            ep_length = self.eplens[env_idx]

            episode = {
                'r': ep_return,
                'l': ep_length,
                't': round(time.time() - self.tstart, 6),
            }
            # A keyword missing from the info dict raises KeyError here.
            episode.update((key, done_info[key]) for key in self.info_keywords)
            done_info['episode'] = episode

            if self.keep_buf:
                self.epret_buf.append(ep_return)
                self.eplen_buf.append(ep_length)

            self.epcount += 1
            # Start a fresh tally for this sub-environment's next episode.
            self.eprets[env_idx] = 0
            self.eplens[env_idx] = 0

            if self.results_writer:
                self.results_writer.write_row(episode)
            newinfos[env_idx] = done_info

        return obs, rews, dones, newinfos