import torch
import torch.nn as nn
from einops.layers.torch import Rearrange

from .models.convnet import FrameEncoder
from .models.slicer import Head
from .models.transformer import TransformerEncoder
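
# Token layout (inferred from the block masks built in __init__ below): each
# timestep is one block of config["transformer_config"]["tokens_per_block"]
# tokens, with the frame embedding at position 0, the action token at
# position 1, and the latent tokens in the remaining positions.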

class WorldModel(nn.Module):
    def __init__(self, config: dict) -> None:
        super().__init__()
        self.config = config
        self.transformer = TransformerEncoder(config["transformer_config"])

        # The flattened frame features must match the transformer width:
        # (image_size / 2**total_downsampling)**2 spatial positions, each with
        # latent_dim channels, concatenated into a single embed_dim vector.
        assert (
            (config["image_size"] // 2 ** sum(config["frame_cnn_config"]["down"])) ** 2
        ) * config["frame_cnn_config"]["latent_dim"] == config["transformer_config"]["embed_dim"]

        self.frame_cnn = nn.Sequential(
            FrameEncoder(config["frame_cnn_config"]),
            Rearrange('b t c h w -> b t 1 (h w c)'),
            nn.LayerNorm(config["transformer_config"]["embed_dim"]),
        )
        self.act_emb = nn.Embedding(config["num_actions"], config["transformer_config"]["embed_dim"])
        self.latents_emb = nn.Embedding(config["latent_vocab_size"], config["transformer_config"]["embed_dim"])

        # Per-block masks selecting which token positions each head reads from:
        # the action token alone for rewards/ends, and the action token plus
        # all but the last latent token for next-latent prediction.
        act_pattern = torch.zeros(config["transformer_config"]["tokens_per_block"])
        act_pattern[1] = 1
        act_and_latents_but_last_pattern = torch.zeros(config["transformer_config"]["tokens_per_block"])
        act_and_latents_but_last_pattern[1:-1] = 1
        # Predicts the next latent token, autoregressively within each block.
        self.head_latents = Head(
            max_blocks=config["transformer_config"]["max_blocks"],
            block_mask=act_and_latents_but_last_pattern,
            head_module=nn.Sequential(
                nn.Linear(config["transformer_config"]["embed_dim"], config["transformer_config"]["embed_dim"]),
                nn.ReLU(),
                nn.Linear(config["transformer_config"]["embed_dim"], config["latent_vocab_size"]),
            ),
        )

        # Predicts rewards: 255 bins for two-hot regression, otherwise a
        # three-class prediction (e.g. reward sign).
        self.head_rewards = Head(
            max_blocks=config["transformer_config"]["max_blocks"],
            block_mask=act_pattern,
            head_module=nn.Sequential(
                nn.Linear(config["transformer_config"]["embed_dim"], config["transformer_config"]["embed_dim"]),
                nn.ReLU(),
                nn.Linear(config["transformer_config"]["embed_dim"], 255 if config["two_hot_rews"] else 3),
            ),
        )

        # Predicts episode termination (binary classification).
        self.head_ends = Head(
            max_blocks=config["transformer_config"]["max_blocks"],
            block_mask=act_pattern,
            head_module=nn.Sequential(
                nn.Linear(config["transformer_config"]["embed_dim"], config["transformer_config"]["embed_dim"]),
                nn.ReLU(),
                nn.Linear(config["transformer_config"]["embed_dim"], 2),
            ),
        )
    def __repr__(self) -> str:
        return "world_model"

    def blocks_left_in_kv_cache(self):
        return self.transformer.num_blocks_left_in_kv_cache

    def reset_kv_cache(self):
        self.transformer.reset_kv_cache(n=1)
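
    # KV-cache protocol (inferred from the helpers above): presumably a caller
    # resets the cache once per rollout, then streams one block at a time
    # through forward(..., use_kv_cache=True), while blocks_left_in_kv_cache()
    # reports the remaining capacity.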

    def forward(self, sequence: torch.FloatTensor, use_kv_cache: bool = False) -> dict:
        # When the KV cache is active, the heads need to know how many tokens
        # are already cached so they index the right positions in each block.
        prev_steps = self.transformer.keys_values.size if use_kv_cache else 0
        num_steps = sequence.size(1)
        outputs = self.transformer(sequence, use_kv_cache=use_kv_cache)
        logits_latents = self.head_latents(outputs, num_steps, prev_steps)
        logits_rewards = self.head_rewards(outputs, num_steps, prev_steps)
        logits_ends = self.head_ends(outputs, num_steps, prev_steps)
        return {
            "output_sequence": outputs,
            "logits_latents": logits_latents,
            "logits_rewards": logits_rewards,
            "logits_ends": logits_ends,
        }
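

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the original module. The config
    # keys below are exactly the ones read in this file; FrameEncoder and
    # TransformerEncoder may require additional keys, and all values here are
    # illustrative assumptions (e.g. that "down" is a list of per-stage
    # downsampling exponents, so 64 // 2**sum([1, 1, 1]) = 8 spatial positions
    # per side and 8 * 8 * 4 = 256 = embed_dim, satisfying the assert).
    embed_dim = 256
    config = {
        "image_size": 64,
        "num_actions": 6,
        "latent_vocab_size": 512,
        "two_hot_rews": True,
        "frame_cnn_config": {"down": [1, 1, 1], "latent_dim": 4},
        "transformer_config": {"embed_dim": embed_dim, "tokens_per_block": 6, "max_blocks": 20},
    }
    wm = WorldModel(config)
    # One block of already-embedded tokens: (batch, tokens, embed_dim).
    tokens = torch.randn(1, config["transformer_config"]["tokens_per_block"], embed_dim)
    out = wm(tokens)
    print({k: tuple(v.shape) for k, v in out.items()})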