Spaces:
Sleeping
Sleeping
File size: 2,592 Bytes
c64c726 17fd5e3 c64c726 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 |
import random
from typing import Generator, Tuple, Union
import torch
import torch.nn as nn
from torch.distributions.categorical import Categorical
from . import coroutine
from ..envs import TorchEnv, WorldModelEnv
@coroutine
def make_env_loop(
env: Union[TorchEnv, WorldModelEnv], model: nn.Module, epsilon: float = 0.0
) -> Generator[Tuple[torch.Tensor, ...], int, None]:
num_steps = yield
hx = torch.zeros(env.num_envs, model.lstm_dim, device=model.device)
cx = torch.zeros(env.num_envs, model.lstm_dim, device=model.device)
seed = random.randint(0, 2**31 - 1)
obs, _ = env.reset(seed=[seed + i for i in range(env.num_envs)])
while True:
hx, cx = hx.detach(), cx.detach()
all_ = []
infos = []
n = 0
while n < num_steps:
logits_act, val, (hx, cx) = model.predict_act_value(obs, (hx, cx))
act = Categorical(logits=logits_act).sample()
if random.random() < epsilon:
act = torch.randint(low=0, high=env.num_actions, size=(obs.size(0),), device=obs.device)
next_obs, rew, end, trunc, info = env.step(act)
if n > 0:
val_bootstrap = val.detach().clone()
if dead.any():
val_bootstrap[dead] = val_final_obs
all_[-1][-1] = val_bootstrap
dead = torch.logical_or(end, trunc)
if dead.any():
with torch.no_grad():
_, val_final_obs, _ = model.predict_act_value(info["final_observation"], (hx[dead], cx[dead]))
reset_gate = 1 - dead.float().unsqueeze(1)
hx = hx * reset_gate
cx = cx * reset_gate
if "burnin_obs" in info:
burnin_obs = info["burnin_obs"]
for i in range(burnin_obs.size(1)):
_, _, (hx[dead], cx[dead]) = model.predict_act_value(burnin_obs[:, i], (hx[dead], cx[dead]))
all_.append([obs, act, rew, end, trunc, logits_act, val, None])
infos.append(info)
obs = next_obs
n += 1
with torch.no_grad():
_, val_bootstrap, _ = model.predict_act_value(next_obs, (hx, cx)) # do not update hx/cx
if dead.any():
val_bootstrap[dead] = val_final_obs
all_[-1][-1] = val_bootstrap
all_obs, act, rew, end, trunc, logits_act, val, val_bootstrap = (torch.stack(x, dim=1) for x in zip(*all_))
num_steps = yield all_obs, act, rew, end, trunc, logits_act, val, val_bootstrap, infos
|