Spaces:
Sleeping
Sleeping
| from dataclasses import dataclass | |
| from itertools import cycle | |
| from pathlib import Path | |
| from typing import Any, Dict, Generator, List, Optional, Tuple | |
| import numpy as np | |
| import torch | |
| from torch import Tensor | |
| from torch.distributions.categorical import Categorical | |
| import torch.nn.functional as F | |
| from ..coroutines import coroutine | |
| from ..models.diffusion import Denoiser, DiffusionSampler, DiffusionSamplerConfig | |
| from ..models.rew_end_model import RewEndModel | |
# Return type of ``WorldModelEnv.reset``: (observation, info dict).
ResetOutput = Tuple[torch.FloatTensor, Dict[str, Any]]

# Return type of ``WorldModelEnv.step``: (observation, rew, end, trunc, info dict).
StepOutput = Tuple[Tensor, Tensor, Tensor, Tensor, Dict[str, Any]]

# Item yielded by the initial-condition generator, in this order:
# (obs, obs_full_res, act, next_act, (hx, cx)).
# ``obs_full_res`` is None when no upsampler is configured, ``next_act`` is a
# list with one tensor per environment, and ``hx``/``cx`` are None when there
# is no reward/end model.
InitialCondition = Tuple[
    Tensor,
    Optional[Tensor],
    Tensor,
    List[Tensor],
    Tuple[Optional[Tensor], Optional[Tensor]],
]
@dataclass
class WorldModelEnvConfig:
    """Configuration consumed by ``WorldModelEnv``.

    Declared as a dataclass so that the annotated fields below become
    constructor arguments and instance attributes.
    """

    # Number of steps after which ``WorldModelEnv.step`` reports truncation.
    horizon: int
    # How many spawn directories are loaded per preload round.
    num_batches_to_preload: int
    # Sampler settings for the next-observation denoiser.
    diffusion_sampler_next_obs: DiffusionSamplerConfig
    # Sampler settings for the upsampler; only read when an upsampler is given.
    diffusion_sampler_upsampling: Optional[DiffusionSamplerConfig] = None
class WorldModelEnv:
    """Environment whose transitions are produced by learned models.

    Observations come from a diffusion sampler (optionally followed by an
    upsampling sampler), rewards and terminations from an optional reward/end
    model, and initial conditions from ``.npy`` files stored in the
    sub-directories of ``spawn_dir``.
    """

    def __init__(
        self,
        denoiser: Denoiser,
        upsampler: Optional[Denoiser],
        rew_end_model: Optional[RewEndModel],
        spawn_dir: Path,
        num_envs: int,
        seq_length: int,
        cfg: WorldModelEnvConfig,
        return_denoising_trajectory: bool = False,
    ) -> None:
        """Build the samplers and the (primed) initial-condition generator.

        Args:
            denoiser: Model wrapped by the next-observation sampler.
            upsampler: Optional model wrapped by the upsampling sampler.
            rew_end_model: Optional reward/termination predictor; when None,
                ``step`` returns zero reward and zero end flags.
            spawn_dir: Directory whose entries each hold one initial condition.
            num_envs: Number of parallel environments; must be 1.
            seq_length: Length of the stored sequences; the conditioning
                window of each sampler is the tail of that sequence.
            cfg: Horizon, preload count and sampler configurations.
            return_denoising_trajectory: If True, ``step`` adds the stacked
                denoising trajectories to its info dict.
        """
        assert num_envs == 1  # for csgo only
        self.sampler_next_obs = DiffusionSampler(denoiser, cfg.diffusion_sampler_next_obs)
        self.sampler_upsampling = (
            None if upsampler is None else DiffusionSampler(upsampler, cfg.diffusion_sampler_upsampling)
        )
        self.rew_end_model = rew_end_model
        self.horizon = cfg.horizon
        self.return_denoising_trajectory = return_denoising_trajectory
        self.num_envs = num_envs
        self.generator_init = self.make_generator_init(spawn_dir, cfg.num_batches_to_preload)

        # Number of leading buffer entries to drop so that each sampler only
        # receives as many frames as it conditions on.
        self.n_skip_next_obs = seq_length - self.sampler_next_obs.denoiser.cfg.inner_model.num_steps_conditioning
        self.n_skip_upsampling = (
            None
            if upsampler is None
            else seq_length - self.sampler_upsampling.denoiser.cfg.inner_model.num_steps_conditioning
        )

    @property
    def device(self) -> torch.device:
        """Device of the next-observation denoiser.

        Exposed as a property because the rest of the class reads it as an
        attribute (``device=self.device``).
        """
        return self.sampler_next_obs.denoiser.device

    def reset(self, **kwargs) -> ResetOutput:
        """Load a fresh initial condition and return its last observation.

        Keyword arguments are accepted and ignored.

        Returns:
            ``(obs, info)`` where ``obs`` is the last frame of the full-res
            buffer if an upsampler is configured, otherwise of the low-res
            buffer, and ``info`` is an empty dict.
        """
        obs, obs_full_res, act, next_act, (hx, cx) = self.generator_init.send(self.num_envs)
        self.obs_buffer = obs
        self.act_buffer = act
        self.next_act = next_act[0]
        self.obs_full_res_buffer = obs_full_res
        self.ep_len = torch.zeros(self.num_envs, dtype=torch.long, device=obs.device)
        self.hx_rew_end = hx
        self.cx_rew_end = cx
        obs_to_return = self.obs_buffer[:, -1] if self.sampler_upsampling is None else self.obs_full_res_buffer[:, -1]
        return obs_to_return, {}

    @torch.no_grad()
    def step(self, act: torch.LongTensor) -> StepOutput:
        """Advance the environment by one step using action ``act``.

        Runs without gradient tracking: the observation buffers and the
        reward/end recurrent state are carried from one call to the next, so
        tracking gradients would chain every step into one growing graph.

        Returns:
            ``(obs, rew, end, trunc, info)``. ``trunc`` is 1 once the episode
            length reaches the horizon. ``info`` holds ``"obs_low_res"`` when
            an upsampler is configured, plus the denoising trajectories when
            ``return_denoising_trajectory`` is set.
        """
        self.act_buffer[:, -1] = act

        next_obs, denoising_trajectory = self.predict_next_obs()

        if self.sampler_upsampling is not None:
            next_obs_full, denoising_trajectory_upsampling = self.upsample_next_obs(next_obs)

        if self.rew_end_model is not None:
            rew, end = self.predict_rew_end(next_obs.unsqueeze(1))
        else:
            rew = torch.zeros(next_obs.size(0), dtype=torch.float32, device=self.device)
            end = torch.zeros(next_obs.size(0), dtype=torch.int64, device=self.device)

        self.ep_len += 1
        trunc = (self.ep_len >= self.horizon).long()

        # Shift the history left by one and overwrite the last slot.
        self.obs_buffer = self.obs_buffer.roll(-1, dims=1)
        self.act_buffer = self.act_buffer.roll(-1, dims=1)
        self.obs_buffer[:, -1] = next_obs

        if self.sampler_upsampling is not None:
            self.obs_full_res_buffer = self.obs_full_res_buffer.roll(-1, dims=1)
            self.obs_full_res_buffer[:, -1] = next_obs_full

        info = {}
        if self.return_denoising_trajectory:
            info["denoising_trajectory"] = torch.stack(denoising_trajectory, dim=1)

        if self.sampler_upsampling is not None:
            info["obs_low_res"] = next_obs
            if self.return_denoising_trajectory:
                info["denoising_trajectory_upsampling"] = torch.stack(denoising_trajectory_upsampling, dim=1)

        obs_to_return = self.obs_buffer[:, -1] if self.sampler_upsampling is None else self.obs_full_res_buffer[:, -1]
        return obs_to_return, rew, end, trunc, info

    def predict_next_obs(self) -> Tuple[Tensor, List[Tensor]]:
        """Sample the next observation from the tail of the obs/act buffers."""
        return self.sampler_next_obs.sample(
            self.obs_buffer[:, self.n_skip_next_obs:],
            self.act_buffer[:, self.n_skip_next_obs:],
        )

    def upsample_next_obs(self, next_obs: Tensor) -> Tuple[Tensor, List[Tensor]]:
        """Sample a full-resolution frame for ``next_obs``.

        ``next_obs`` is first resized with bicubic interpolation by the
        upsampler's ``upsampling_factor`` and appended, along dim 1, to the
        tail of the full-resolution buffer. The sampler is called with
        ``None`` as its second argument.
        """
        low_res = F.interpolate(
            next_obs,
            scale_factor=self.sampler_upsampling.denoiser.cfg.upsampling_factor,
            mode="bicubic",
        ).unsqueeze(1)
        conditioning = torch.cat((self.obs_full_res_buffer[:, self.n_skip_upsampling:], low_res), dim=1)
        return self.sampler_upsampling.sample(conditioning, None)

    def predict_rew_end(self, next_obs: Tensor) -> Tuple[Tensor, Tensor]:
        """Sample reward and end flag for the transition into ``next_obs``.

        Feeds the latest buffered observation/action and the stored recurrent
        state to the reward/end model, then stores the updated state.
        """
        logits_rew, logits_end, (self.hx_rew_end, self.cx_rew_end) = self.rew_end_model.predict_rew_end(
            self.obs_buffer[:, -1:],
            self.act_buffer[:, -1:],
            next_obs,
            (self.hx_rew_end, self.cx_rew_end),
        )
        rew = Categorical(logits=logits_rew).sample().squeeze(1) - 1.0  # in {-1, 0, 1}
        end = Categorical(logits=logits_end).sample().squeeze(1)
        return rew, end

    def make_generator_init(
        self,
        spawn_dir: Path,
        num_batches_to_preload: int,
    ) -> Generator[InitialCondition, int, None]:
        """Return a generator of initial conditions that is ready for ``send``.

        The generator is advanced to its first bare ``yield`` before being
        returned; sending a value to a generator that has not started raises
        ``TypeError``. Each ``send(num_dead)`` returns initial conditions for
        ``num_dead`` environments.
        """
        generator = self._iter_initial_conditions(spawn_dir, num_batches_to_preload)
        next(generator)  # run up to the first `num_dead = yield`
        return generator

    def _iter_initial_conditions(
        self,
        spawn_dir: Path,
        num_batches_to_preload: int,
    ) -> Generator[Optional[InitialCondition], int, None]:
        """Endlessly preload spawn directories and yield batches from them."""
        num_dead = yield

        spawn_paths = sorted(spawn_dir.iterdir())
        if not spawn_paths:
            # cycle() over nothing would surface as an opaque RuntimeError.
            raise FileNotFoundError(f"no spawn directories found in {spawn_dir}")
        spawn_dirs = cycle(spawn_paths)

        while True:
            # Preload on device and burnin rew/end model
            obs_, obs_full_res_, act_, next_act_, hx_, cx_ = [], [], [], [], [], []
            for _ in range(num_batches_to_preload):
                d = next(spawn_dirs)
                obs = self._load_frames(d / "low_res.npy")
                obs_full_res = self._load_frames(d / "full_res.npy")
                act = self._load_actions(d / "act.npy")
                next_act = self._load_actions(d / "next_act.npy")

                # Each loaded tensor has a leading dim of 1, so every
                # directory contributes exactly one entry per list.
                obs_.extend(obs)
                obs_full_res_.extend(obs_full_res)
                act_.extend(act)
                next_act_.extend(next_act)

                if self.rew_end_model is not None:
                    with torch.no_grad():
                        # Burn-in of rew/end model on the tensor just loaded
                        # (not on the accumulating list `obs_`).
                        *_, (hx, cx) = self.rew_end_model.predict_rew_end(obs[:, :-1], act[:, :-1], obs[:, 1:])
                    assert hx.size(0) == cx.size(0) == 1
                    hx_.extend(hx[0])
                    cx_.extend(cx[0])

            # Yield new initial conditions for dead envs
            c = 0
            while c + num_dead <= len(obs_):
                batch = slice(c, c + num_dead)
                obs = torch.stack(obs_[batch])
                act = torch.stack(act_[batch])
                next_act = next_act_[batch]
                obs_full_res = torch.stack(obs_full_res_[batch]) if self.sampler_upsampling is not None else None
                hx = torch.stack(hx_[batch]).unsqueeze(0) if self.rew_end_model is not None else None
                cx = torch.stack(cx_[batch]).unsqueeze(0) if self.rew_end_model is not None else None
                c += num_dead
                num_dead = yield obs, obs_full_res, act, next_act, (hx, cx)

    def _load_frames(self, path: Path) -> Tensor:
        """Load an array, map values x to x / 255 * 2 - 1, add a leading dim."""
        # presumably the file stores frames with values in [0, 255] -- TODO confirm
        frames = torch.tensor(np.load(path), device=self.device)
        return frames.div(255).mul(2).sub(1).unsqueeze(0)

    def _load_actions(self, path: Path) -> Tensor:
        """Load an array as a long tensor on the env device, add a leading dim."""
        return torch.tensor(np.load(path), dtype=torch.long, device=self.device).unsqueeze(0)