Spaces:
Sleeping
Sleeping
| import random | |
| from typing import Generator, Tuple, Union | |
| import torch | |
| import torch.nn as nn | |
| from torch.distributions.categorical import Categorical | |
| from . import coroutine | |
| from ..envs import TorchEnv, WorldModelEnv | |
def make_env_loop(
    env: Union[TorchEnv, WorldModelEnv], model: nn.Module, epsilon: float = 0.0
) -> Generator[Tuple[torch.Tensor, ...], int, None]:
    """Roll `model` out in `env` and yield fixed-length chunks of experience.

    Protocol:
        * The first ``yield`` produces nothing and receives the number of
          steps to collect for the first chunk.
        * Every following ``yield`` hands back one chunk and receives the
          number of steps for the next one.
        NOTE(review): ``coroutine`` is imported by this module but no decorator
        is visible on this function; if it is used undecorated, callers must
        prime the generator with ``next()`` before the first ``send()``.

    The environments are reset once, with consecutive seeds derived from a
    single random base seed. Observations and the recurrent state ``(hx, cx)``
    persist across chunks; the recurrent state is detached at the start of
    each chunk.

    Args:
        env: Batched environment exposing ``num_envs``, ``num_actions``,
            ``reset(seed=...)`` and ``step(act)``.
        model: Module exposing ``lstm_dim``, ``device`` and
            ``predict_act_value(obs, (hx, cx))`` which returns
            ``(logits_act, val, (hx, cx))``.
        epsilon: Probability, drawn once per step for the whole batch, of
            replacing the sampled actions with uniformly random ones.

    Yields:
        ``(all_obs, act, rew, end, trunc, logits_act, val, val_bootstrap,
        infos)``. The first eight entries are the per-step values stacked with
        ``torch.stack(..., dim=1)``; ``infos`` is the list of per-step ``info``
        objects returned by ``env.step``.

    Raises:
        ValueError: If the requested number of steps is smaller than 1.
    """
    num_steps = yield

    hx = torch.zeros(env.num_envs, model.lstm_dim, device=model.device)
    cx = torch.zeros(env.num_envs, model.lstm_dim, device=model.device)

    # One base seed, offset per environment so each env gets a distinct seed.
    seed = random.randint(0, 2**31 - 1)
    obs, _ = env.reset(seed=[seed + i for i in range(env.num_envs)])

    while True:
        # An empty chunk cannot be stacked or bootstrapped; fail with a clear
        # message instead of an UnboundLocalError / IndexError further down.
        if num_steps < 1:
            raise ValueError(f"num_steps must be a positive integer, got {num_steps!r}")

        # Cut the autograd graph so gradients do not flow into earlier chunks.
        hx, cx = hx.detach(), cx.detach()
        all_ = []
        infos = []
        n = 0

        while n < num_steps:
            logits_act, val, (hx, cx) = model.predict_act_value(obs, (hx, cx))
            act = Categorical(logits=logits_act).sample()

            # Epsilon exploration: a single draw decides for the whole batch.
            if random.random() < epsilon:
                act = torch.randint(low=0, high=env.num_actions, size=(obs.size(0),), device=obs.device)

            next_obs, rew, end, trunc, info = env.step(act)

            if n > 0:
                # `val` was predicted from the previous step's next observation,
                # so it is the bootstrap value of the previous transition.
                # `dead` and `val_final_obs` still refer to the previous step.
                val_bootstrap = val.detach().clone()
                if dead.any():
                    val_bootstrap[dead] = val_final_obs
                all_[-1][-1] = val_bootstrap

            dead = torch.logical_or(end, trunc)

            if dead.any():
                # Value of the final observation, computed with the recurrent
                # state as it was before the reset below.
                # NOTE(review): presumably info["final_observation"] holds one
                # row per finished env, matching hx[dead] - confirm against env.
                with torch.no_grad():
                    _, val_final_obs, _ = model.predict_act_value(info["final_observation"], (hx[dead], cx[dead]))
                # Zero the recurrent state of the environments that finished.
                reset_gate = 1 - dead.float().unsqueeze(1)
                hx = hx * reset_gate
                cx = cx * reset_gate
                if "burnin_obs" in info:
                    # Rebuild the recurrent state of finished envs by feeding
                    # the provided burn-in observations one step at a time.
                    burnin_obs = info["burnin_obs"]
                    for i in range(burnin_obs.size(1)):
                        _, _, (hx[dead], cx[dead]) = model.predict_act_value(burnin_obs[:, i], (hx[dead], cx[dead]))

            # The last slot (bootstrap value) is filled in on the next step,
            # or after the loop for the final transition of the chunk.
            all_.append([obs, act, rew, end, trunc, logits_act, val, None])
            infos.append(info)

            obs = next_obs
            n += 1

        # Bootstrap value for the last transition of the chunk.
        with torch.no_grad():
            _, val_bootstrap, _ = model.predict_act_value(next_obs, (hx, cx))  # do not update hx/cx

        if dead.any():
            val_bootstrap[dead] = val_final_obs

        all_[-1][-1] = val_bootstrap

        # Transpose the list of per-step records into per-field stacks (step axis = 1).
        all_obs, act, rew, end, trunc, logits_act, val, val_bootstrap = (torch.stack(x, dim=1) for x in zip(*all_))

        num_steps = yield all_obs, act, rew, end, trunc, logits_act, val, val_bootstrap, infos