# File size: 1,971 Bytes
from . import VecEnvWrapper
from baselines.bench.monitor import ResultsWriter
import numpy as np
import time
from collections import deque
class VecMonitor(VecEnvWrapper):
    """Wrapper that records per-environment episode return and length.

    Rewards and step counts are accumulated for each sub-environment. When a
    sub-environment reports ``done``, a summary dict is stored under the
    ``'episode'`` key of a copy of that environment's info dict, optionally
    written through a ``ResultsWriter``, and the running totals for that
    environment are reset to zero.
    """

    def __init__(self, venv, filename=None, keep_buf=0, info_keywords=()):
        """Set up the bookkeeping state around ``venv``.

        Args:
            venv: The vectorized environment being wrapped.
            filename: If truthy, passed to ``ResultsWriter`` so that each
                finished episode is written as a row; otherwise no writer
                is created.
            keep_buf: If truthy, used as ``maxlen`` of two deques
                (``epret_buf`` and ``eplen_buf``) that keep the most recent
                episode returns and lengths. The deques are not created
                when this is falsy.
            info_keywords: Keys copied from a finished environment's info
                dict into the episode summary.
        """
        VecEnvWrapper.__init__(self, venv)
        # The running totals are only allocated by reset().
        self.eprets = None
        self.eplens = None
        self.epcount = 0
        self.tstart = time.time()
        self.results_writer = (
            ResultsWriter(filename, header={'t_start': self.tstart},
                          extra_keys=info_keywords)
            if filename
            else None
        )
        self.info_keywords = info_keywords
        self.keep_buf = keep_buf
        if self.keep_buf:
            self.epret_buf = deque([], maxlen=keep_buf)
            self.eplen_buf = deque([], maxlen=keep_buf)

    def reset(self):
        """Reset the wrapped environment and zero the running totals.

        Returns:
            Whatever ``self.venv.reset()`` returns.
        """
        first_obs = self.venv.reset()
        # num_envs is not set in this class; presumably provided by
        # VecEnvWrapper -- confirm against the base class.
        self.eprets = np.zeros(self.num_envs, 'f')
        self.eplens = np.zeros(self.num_envs, 'i')
        return first_obs

    def step_wait(self):
        """Collect a step result and attach summaries for finished episodes.

        Returns:
            ``(obs, rews, dones, newinfos)`` where the first three items are
            passed through from ``self.venv.step_wait()`` and ``newinfos``
            is a list of the info entries in which each finished
            environment's entry is replaced by a copy holding an
            ``'episode'`` dict with keys ``'r'`` (accumulated reward),
            ``'l'`` (step count), ``'t'`` (seconds since construction,
            rounded to 6 decimals) plus every key in ``info_keywords``.
        """
        obs, rews, dones, infos = self.venv.step_wait()
        self.eprets += rews
        self.eplens += 1

        newinfos = list(infos[:])
        for env_idx, finished in enumerate(dones):
            if not finished:
                continue

            # Work on a copy so the info dict handed over by venv is not mutated.
            done_info = infos[env_idx].copy()
            # Read the totals before they are zeroed further down.
            ep_return = self.eprets[env_idx]
            ep_length = self.eplens[env_idx]
            elapsed = round(time.time() - self.tstart, 6)

            summary = {'r': ep_return, 'l': ep_length, 't': elapsed}
            summary.update((key, done_info[key]) for key in self.info_keywords)
            done_info['episode'] = summary

            if self.keep_buf:
                self.epret_buf.append(ep_return)
                self.eplen_buf.append(ep_length)

            self.epcount += 1
            self.eprets[env_idx] = 0
            self.eplens[env_idx] = 0

            if self.results_writer:
                self.results_writer.write_row(summary)
            newinfos[env_idx] = done_info

        return obs, rews, dones, newinfos
|