| from einops import rearrange | |
| from einops.layers.torch import Rearrange | |
| import torch | |
| import torch.nn as nn | |
| from .models.convnet import FrameEncoder | |
| from .models.slicer import Head | |
| from .models.transformer import TransformerEncoder | |
class WorldModel(nn.Module):
    """Transformer world model predicting next latents, rewards, and episode ends.

    A frame CNN flattens each observation into a single transformer token; action
    and latent tokens are embedded separately. Per-block token masks route the
    transformer outputs to three prediction heads (latents, rewards, ends).

    Args:
        config: dict with keys ``transformer_config``, ``frame_cnn_config``,
            ``image_size``, ``num_actions``, ``latent_vocab_size``,
            ``two_hot_rews``.

    Raises:
        ValueError: if the flattened frame-CNN feature size does not equal the
            transformer embedding dimension.
    """

    def __init__(self, config: dict) -> None:
        super().__init__()
        self.config = config

        # Hoist the deeply nested config accesses used repeatedly below.
        t_cfg = config["transformer_config"]
        cnn_cfg = config["frame_cnn_config"]
        embed_dim = t_cfg["embed_dim"]
        tokens_per_block = t_cfg["tokens_per_block"]
        max_blocks = t_cfg["max_blocks"]

        self.transformer = TransformerEncoder(t_cfg)

        # The flattened CNN feature map must exactly fill one transformer token.
        # Explicit check instead of `assert`: asserts are stripped under -O and
        # gave no diagnostic message.
        feat_hw = config["image_size"] // 2 ** sum(cnn_cfg["down"])
        frame_token_dim = (feat_hw ** 2) * cnn_cfg["latent_dim"]
        if frame_token_dim != embed_dim:
            raise ValueError(
                f"frame CNN output size {frame_token_dim} "
                f"(({feat_hw}x{feat_hw}) x {cnn_cfg['latent_dim']}) does not match "
                f"transformer embed_dim {embed_dim}"
            )

        # Encode a frame, flatten (h, w, c) into one token per timestep, normalize.
        self.frame_cnn = nn.Sequential(
            FrameEncoder(cnn_cfg),
            Rearrange('b t c h w -> b t 1 (h w c)'),
            nn.LayerNorm(embed_dim),
        )
        self.act_emb = nn.Embedding(config["num_actions"], embed_dim)
        self.latents_emb = nn.Embedding(config["latent_vocab_size"], embed_dim)

        # Block masks select which token positions each head reads.
        # Position 0 is presumably the frame token, position 1 the action token,
        # and the remainder latent tokens — TODO confirm against the tokenizer.
        act_pattern = torch.zeros(tokens_per_block)
        act_pattern[1] = 1
        act_and_latents_but_last_pattern = torch.zeros(tokens_per_block)
        act_and_latents_but_last_pattern[1:-1] = 1

        def _mlp(out_dim: int) -> nn.Sequential:
            # Shared two-layer head shape: embed_dim -> embed_dim -> out_dim.
            return nn.Sequential(
                nn.Linear(embed_dim, embed_dim),
                nn.ReLU(),
                nn.Linear(embed_dim, out_dim),
            )

        # Next-latent prediction from action + all-but-last latent positions.
        self.head_latents = Head(
            max_blocks=max_blocks,
            block_mask=act_and_latents_but_last_pattern,
            head_module=_mlp(config["latent_vocab_size"]),
        )
        # Reward: 255 bins for two-hot encoding, else 3-way classification.
        self.head_rewards = Head(
            max_blocks=max_blocks,
            block_mask=act_pattern,
            head_module=_mlp(255 if config["two_hot_rews"] else 3),
        )
        # Episode-end: binary classification.
        self.head_ends = Head(
            max_blocks=max_blocks,
            block_mask=act_pattern,
            head_module=_mlp(2),
        )

    def __repr__(self) -> str:
        # Fixed short name (used as an identifier, e.g. for checkpoints),
        # not a debug repr.
        return "world_model"

    def blocks_left_in_kv_cache(self):
        """Return how many blocks can still be appended to the KV cache."""
        return self.transformer.num_blocks_left_in_kv_cache

    def reset_kv_cache(self):
        """Reset the transformer KV cache for a single sequence (n=1)."""
        self.transformer.reset_kv_cache(n=1)

    def forward(self, sequence: torch.FloatTensor, use_kv_cache: bool = False) -> dict:
        """Run the transformer over an embedded token sequence and apply the heads.

        Args:
            sequence: token embeddings of shape (batch, num_steps, embed_dim) —
                assumed; TODO confirm against callers.
            use_kv_cache: continue from the cached keys/values instead of
                starting at step 0.

        Returns:
            dict with ``output_sequence`` (transformer outputs) and per-head
            logits: ``logits_latents``, ``logits_rewards``, ``logits_ends``.
        """
        # When continuing from the cache, heads must know the absolute offset
        # of the new tokens within the block structure.
        prev_steps = self.transformer.keys_values.size if use_kv_cache else 0
        num_steps = sequence.size(1)
        outputs = self.transformer(sequence, use_kv_cache=use_kv_cache)
        return {
            "output_sequence": outputs,
            "logits_latents": self.head_latents(outputs, num_steps, prev_steps),
            "logits_rewards": self.head_rewards(outputs, num_steps, prev_steps),
            "logits_ends": self.head_ends(outputs, num_steps, prev_steps),
        }