# PIWM/src/envs/world_model_env.py
# (commit 17fd5e3 by musictimer: "Fix bug 1")
from dataclasses import dataclass
from itertools import cycle
from pathlib import Path
from typing import Any, Dict, Generator, List, Optional, Tuple
import numpy as np
import torch
from torch import Tensor
from torch.distributions.categorical import Categorical
import torch.nn.functional as F
from ..coroutines import coroutine
from ..models.diffusion import Denoiser, DiffusionSampler, DiffusionSamplerConfig
from ..models.rew_end_model import RewEndModel
# Gym-style API types for WorldModelEnv.
ResetOutput = Tuple[torch.FloatTensor, Dict[str, Any]]  # (obs, info)
StepOutput = Tuple[Tensor, Tensor, Tensor, Tensor, Dict[str, Any]]  # (obs, rew, end, trunc, info)
# What make_generator_init actually yields: (obs, obs_full_res, act, next_act, (hx, cx)).
# obs_full_res is None when upsampling is disabled; hx/cx are None without a rew/end model.
# (Was declared as a 3-tuple, out of sync with the 5-tuple yielded by the generator.)
InitialCondition = Tuple[Tensor, Optional[Tensor], Tensor, List[Tensor], Tuple[Optional[Tensor], Optional[Tensor]]]
@dataclass
class WorldModelEnvConfig:
    """Configuration for WorldModelEnv."""

    horizon: int  # max episode length; step() truncates once ep_len reaches this
    num_batches_to_preload: int  # number of spawn directories loaded per generator refill
    diffusion_sampler_next_obs: DiffusionSamplerConfig  # sampler config for next-frame prediction
    diffusion_sampler_upsampling: Optional[DiffusionSamplerConfig] = None  # None disables upsampling
class WorldModelEnv:
    """Gym-like environment backed by a diffusion world model.

    Next observations are sampled with a DiffusionSampler over `denoiser`;
    rewards and terminations come from an optional RewEndModel; an optional
    `upsampler` denoiser produces full-resolution frames. Initial conditions
    are streamed from .npy files found under `spawn_dir`.
    """

    def __init__(
        self,
        denoiser: Denoiser,
        upsampler: Optional[Denoiser],
        rew_end_model: Optional[RewEndModel],
        spawn_dir: Path,
        num_envs: int,
        seq_length: int,
        cfg: WorldModelEnvConfig,
        return_denoising_trajectory: bool = False,
    ) -> None:
        assert num_envs == 1  # for csgo only
        self.sampler_next_obs = DiffusionSampler(denoiser, cfg.diffusion_sampler_next_obs)
        # Upsampling is optional; when disabled, all full-res state stays None.
        self.sampler_upsampling = None if upsampler is None else DiffusionSampler(upsampler, cfg.diffusion_sampler_upsampling)
        self.rew_end_model = rew_end_model
        self.horizon = cfg.horizon
        self.return_denoising_trajectory = return_denoising_trajectory
        self.num_envs = num_envs
        # Coroutine yielding fresh initial conditions; driven via .send() in reset().
        self.generator_init = self.make_generator_init(spawn_dir, cfg.num_batches_to_preload)
        # Buffers hold `seq_length` frames but each sampler conditions only on its
        # last `num_steps_conditioning` frames -- these offsets skip the excess.
        self.n_skip_next_obs = seq_length - self.sampler_next_obs.denoiser.cfg.inner_model.num_steps_conditioning
        self.n_skip_upsampling = None if upsampler is None else seq_length - self.sampler_upsampling.denoiser.cfg.inner_model.num_steps_conditioning
@property
def device(self) -> torch.device:
return self.sampler_next_obs.denoiser.device
    @torch.no_grad()
    def reset(self, **kwargs) -> ResetOutput:
        """Start new episodes from preloaded spawn data.

        Pulls `num_envs` fresh initial conditions from the generator and
        (re)initializes the rolling obs/act buffers, episode-length counter,
        and the rew/end model's LSTM state. `kwargs` is accepted for gym
        API compatibility and ignored. Returns (obs, info).
        """
        obs, obs_full_res, act, next_act, (hx, cx) = self.generator_init.send(self.num_envs)
        self.obs_buffer = obs
        self.act_buffer = act
        # num_envs == 1 (asserted in __init__), so keep only env 0's next actions.
        self.next_act = next_act[0]
        self.obs_full_res_buffer = obs_full_res
        self.ep_len = torch.zeros(self.num_envs, dtype=torch.long, device=obs.device)
        # Burned-in LSTM state of the rew/end model (None when that model is disabled).
        self.hx_rew_end = hx
        self.cx_rew_end = cx
        # Return full-resolution frames when an upsampler is configured.
        obs_to_return = self.obs_buffer[:, -1] if self.sampler_upsampling is None else self.obs_full_res_buffer[:, -1]
        return obs_to_return, {}
    @torch.no_grad()
    def step(self, act: torch.LongTensor) -> StepOutput:
        """Advance all envs by one imagined timestep.

        Writes `act` into the last action-buffer slot, samples the next
        observation with the diffusion model, optionally upsamples it and
        predicts reward/termination, then rolls the buffers one step left.
        Returns (obs, rew, end, trunc, info).
        """
        self.act_buffer[:, -1] = act
        next_obs, denoising_trajectory = self.predict_next_obs()
        if self.sampler_upsampling is not None:
            next_obs_full, denoising_trajectory_upsampling = self.upsample_next_obs(next_obs)
        if self.rew_end_model is not None:
            # Must run BEFORE the buffers are rolled: it reads the current last obs/act.
            rew, end = self.predict_rew_end(next_obs.unsqueeze(1))
        else:
            # No rew/end model: zero reward, never terminate.
            rew = torch.zeros(next_obs.size(0), dtype=torch.float32, device=self.device)
            end = torch.zeros(next_obs.size(0), dtype=torch.int64, device=self.device)
        self.ep_len += 1
        trunc = (self.ep_len >= self.horizon).long()  # truncate episodes at the horizon
        # Shift the conditioning windows one step left and append the new frame.
        self.obs_buffer = self.obs_buffer.roll(-1, dims=1)
        self.act_buffer = self.act_buffer.roll(-1, dims=1)
        self.obs_buffer[:, -1] = next_obs
        if self.sampler_upsampling is not None:
            self.obs_full_res_buffer = self.obs_full_res_buffer.roll(-1, dims=1)
            self.obs_full_res_buffer[:, -1] = next_obs_full
        info = {}
        if self.return_denoising_trajectory:
            info["denoising_trajectory"] = torch.stack(denoising_trajectory, dim=1)
        if self.sampler_upsampling is not None:
            # Also expose the low-res frame when the returned obs is full-res.
            info["obs_low_res"] = next_obs
            if self.return_denoising_trajectory:
                info["denoising_trajectory_upsampling"] = torch.stack(denoising_trajectory_upsampling, dim=1)
        obs_to_return = self.obs_buffer[:, -1] if self.sampler_upsampling is None else self.obs_full_res_buffer[:, -1]
        return obs_to_return, rew, end, trunc, info
@torch.no_grad()
def predict_next_obs(self) -> Tuple[Tensor, List[Tensor]]:
return self.sampler_next_obs.sample(self.obs_buffer[:, self.n_skip_next_obs:], self.act_buffer[:, self.n_skip_next_obs:])
@torch.no_grad()
def upsample_next_obs(self, next_obs: Tensor) -> Tuple[Tensor, List[Tensor]]:
low_res = F.interpolate(next_obs, scale_factor=self.sampler_upsampling.denoiser.cfg.upsampling_factor, mode="bicubic").unsqueeze(1)
return self.sampler_upsampling.sample(torch.cat((self.obs_full_res_buffer[:, self.n_skip_upsampling:], low_res), dim=1), None)
@torch.no_grad()
def predict_rew_end(self, next_obs: Tensor) -> Tuple[Tensor, Tensor]:
logits_rew, logits_end, (self.hx_rew_end, self.cx_rew_end) = self.rew_end_model.predict_rew_end(
self.obs_buffer[:, -1:],
self.act_buffer[:, -1:],
next_obs,
(self.hx_rew_end, self.cx_rew_end),
)
rew = Categorical(logits=logits_rew).sample().squeeze(1) - 1.0 # in {-1, 0, 1}
end = Categorical(logits=logits_end).sample().squeeze(1)
return rew, end
@coroutine
def make_generator_init(
self,
spawn_dir: Path,
num_batches_to_preload: int,
) -> Generator[InitialCondition, None, None]:
num_dead = yield
spawn_dirs = cycle(sorted(list(spawn_dir.iterdir())))
while True:
# Preload on device and burnin rew/end model
obs_, obs_full_res_, act_, next_act_, hx_, cx_ = [], [], [], [], [], []
for _ in range(num_batches_to_preload):
d = next(spawn_dirs)
obs = torch.tensor(np.load(d / "low_res.npy"), device=self.device).div(255).mul(2).sub(1).unsqueeze(0)
obs_full_res = torch.tensor(np.load(d / "full_res.npy"), device=self.device).div(255).mul(2).sub(1).unsqueeze(0)
act = torch.tensor(np.load(d / "act.npy"), dtype=torch.long, device=self.device).unsqueeze(0)
next_act = torch.tensor(np.load(d / "next_act.npy"), dtype=torch.long, device=self.device).unsqueeze(0)
obs_.extend(list(obs))
obs_full_res_.extend(list(obs_full_res))
act_.extend(list(act))
next_act_.extend(list(next_act))
if self.rew_end_model is not None:
with torch.no_grad():
*_, (hx, cx) = self.rew_end_model.predict_rew_end(obs_[:, :-1], act[:, :-1], obs[:, 1:]) # Burn-in of rew/end model
assert hx.size(0) == cx.size(0) == 1
hx_.extend(list(hx[0]))
cx_.extend(list(cx[0]))
# Yield new initial conditions for dead envs
c = 0
while c + num_dead <= len(obs_):
obs = torch.stack(obs_[c : c + num_dead])
act = torch.stack(act_[c : c + num_dead])
next_act = next_act_[c : c + num_dead]
obs_full_res = torch.stack(obs_full_res_[c : c + num_dead]) if self.sampler_upsampling is not None else None
hx = torch.stack(hx_[c : c + num_dead]).unsqueeze(0) if self.rew_end_model is not None else None
cx = torch.stack(cx_[c : c + num_dead]).unsqueeze(0) if self.rew_end_model is not None else None
c += num_dead
num_dead = yield obs, obs_full_res, act, next_act, (hx, cx)