from einops import rearrange
from einops.layers.torch import Rearrange
import torch
import torch.nn as nn

from .models.convnet import FrameEncoder
from .models.slicer import Head
from .models.transformer import TransformerEncoder


class WorldModel(nn.Module):
    """Transformer world model over interleaved frame/action/latent tokens.

    A CNN embeds observation frames, actions and latents get their own
    embedding tables, and a transformer encoder processes the resulting token
    sequence. Three heads read selected token positions to predict the next
    latents, the reward, and whether the episode ends.
    """

    def __init__(self, config: dict) -> None:
        super().__init__()
        self.config = config

        transformer_config = config["transformer_config"]
        frame_cnn_config = config["frame_cnn_config"]
        embed_dim = transformer_config["embed_dim"]

        self.transformer = TransformerEncoder(transformer_config)

        # The flattened CNN feature map (h * w * c after downsampling) must
        # exactly fill one transformer token of width embed_dim.
        feature_map_side = config["image_size"] // 2 ** sum(frame_cnn_config["down"])
        assert feature_map_side ** 2 * frame_cnn_config["latent_dim"] == embed_dim

        self.frame_cnn = nn.Sequential(
            FrameEncoder(frame_cnn_config),
            Rearrange('b t c h w -> b t 1 (h w c)'),
            nn.LayerNorm(embed_dim),
        )
        self.act_emb = nn.Embedding(config["num_actions"], embed_dim)
        self.latents_emb = nn.Embedding(config["latent_vocab_size"], embed_dim)

        # Block-level masks selecting which token positions feed each head.
        # NOTE(review): position 1 appears to be the action token and positions
        # 1..-2 the action plus all-but-last latents — confirm against the
        # tokenization used by the caller.
        act_pattern = torch.zeros(transformer_config["tokens_per_block"])
        act_pattern[1] = 1
        act_and_latents_but_last_pattern = torch.zeros(transformer_config["tokens_per_block"])
        act_and_latents_but_last_pattern[1:-1] = 1

        def make_head(block_mask: torch.Tensor, output_dim: int) -> Head:
            # Two-layer MLP head applied at the positions selected by block_mask.
            return Head(
                max_blocks=transformer_config["max_blocks"],
                block_mask=block_mask,
                head_module=nn.Sequential(
                    nn.Linear(embed_dim, embed_dim),
                    nn.ReLU(),
                    nn.Linear(embed_dim, output_dim),
                ),
            )

        self.head_latents = make_head(act_and_latents_but_last_pattern, config["latent_vocab_size"])
        # Rewards: 255 bins when two-hot encoding is enabled, else a 3-way
        # classification (presumably sign of reward — verify against the loss).
        self.head_rewards = make_head(act_pattern, 255 if config["two_hot_rews"] else 3)
        # Episode termination is a binary prediction.
        self.head_ends = make_head(act_pattern, 2)

    def __repr__(self) -> str:
        return "world_model"

    def blocks_left_in_kv_cache(self):
        """Return how many blocks can still be appended to the KV cache."""
        return self.transformer.num_blocks_left_in_kv_cache

    def reset_kv_cache(self, n: int = 1) -> None:
        """Reset the transformer's KV cache.

        Args:
            n: number of parallel sequences the cache should hold. Defaults
                to 1, matching the previous hard-coded behavior.
        """
        self.transformer.reset_kv_cache(n=n)

    def forward(self, sequence: torch.FloatTensor, use_kv_cache: bool = False) -> dict:
        """Run the transformer over an embedded token sequence and apply all heads.

        Args:
            sequence: embedded token sequence; assumed (batch, num_tokens,
                embed_dim) since dim 1 is treated as the step axis — TODO confirm.
            use_kv_cache: if True, continue from the cached keys/values and
                offset the heads by the number of steps already cached.

        Returns:
            Dict with the raw transformer outputs and the logits of the
            latents, rewards, and ends heads.
        """
        prev_steps = self.transformer.keys_values.size if use_kv_cache else 0
        num_steps = sequence.size(1)
        outputs = self.transformer(sequence, use_kv_cache=use_kv_cache)
        return {
            "output_sequence": outputs,
            "logits_latents": self.head_latents(outputs, num_steps, prev_steps),
            "logits_rewards": self.head_rewards(outputs, num_steps, prev_steps),
            "logits_ends": self.head_ends(outputs, num_steps, prev_steps),
        }