# sai_wm/delta-iris/src/world_model.py
from einops import rearrange
from einops.layers.torch import Rearrange
import torch
import torch.nn as nn

from .models.convnet import FrameEncoder
from .models.slicer import Head
from .models.transformer import TransformerEncoder


class WorldModel(nn.Module):
    """Transformer world model operating on interleaved frame, action, and latent tokens."""

    def __init__(self, config: dict) -> None:
        super().__init__()
        self.config = config

        frame_cnn_config = config["frame_cnn_config"]
        transformer_config = config["transformer_config"]

        self.transformer = TransformerEncoder(transformer_config)

        # The flattened frame encoding must match the transformer embedding width.
        assert ((config["image_size"] // 2 ** sum(frame_cnn_config["down"])) ** 2) * frame_cnn_config["latent_dim"] == transformer_config["embed_dim"]

        # Encode each frame into a single token of size embed_dim.
        self.frame_cnn = nn.Sequential(
            FrameEncoder(frame_cnn_config),
            Rearrange('b t c h w -> b t 1 (h w c)'),
            nn.LayerNorm(transformer_config["embed_dim"]),
        )

        # Token embeddings for actions and discrete latents.
        self.act_emb = nn.Embedding(config["num_actions"], transformer_config["embed_dim"])
        self.latents_emb = nn.Embedding(config["latent_vocab_size"], transformer_config["embed_dim"])

        # Block layout: position 0 holds the frame token, position 1 the action token,
        # and the remaining positions the latent tokens.
        act_pattern = torch.zeros(transformer_config["tokens_per_block"])
        act_pattern[1] = 1
        act_and_latents_but_last_pattern = torch.zeros(transformer_config["tokens_per_block"])
        act_and_latents_but_last_pattern[1:-1] = 1

        # Predict the next latent token from the action token and every latent but the last.
        self.head_latents = Head(
            max_blocks=transformer_config["max_blocks"],
            block_mask=act_and_latents_but_last_pattern,
            head_module=nn.Sequential(
                nn.Linear(transformer_config["embed_dim"], transformer_config["embed_dim"]),
                nn.ReLU(),
                nn.Linear(transformer_config["embed_dim"], config["latent_vocab_size"]),
            ),
        )

        # Predict reward from the action token (255 bins for two-hot targets, otherwise 3 classes).
        self.head_rewards = Head(
            max_blocks=transformer_config["max_blocks"],
            block_mask=act_pattern,
            head_module=nn.Sequential(
                nn.Linear(transformer_config["embed_dim"], transformer_config["embed_dim"]),
                nn.ReLU(),
                nn.Linear(transformer_config["embed_dim"], 255 if config["two_hot_rews"] else 3),
            ),
        )

        # Predict episode termination from the action token.
        self.head_ends = Head(
            max_blocks=transformer_config["max_blocks"],
            block_mask=act_pattern,
            head_module=nn.Sequential(
                nn.Linear(transformer_config["embed_dim"], transformer_config["embed_dim"]),
                nn.ReLU(),
                nn.Linear(transformer_config["embed_dim"], 2),
            ),
        )

    def __repr__(self) -> str:
        return "world_model"

    def blocks_left_in_kv_cache(self):
        return self.transformer.num_blocks_left_in_kv_cache

    def reset_kv_cache(self) -> None:
        self.transformer.reset_kv_cache(n=1)

    def forward(self, sequence: torch.FloatTensor, use_kv_cache: bool = False) -> dict:
        # When decoding with the KV cache, the heads need to know how many steps were already processed.
        prev_steps = self.transformer.keys_values.size if use_kv_cache else 0
        num_steps = sequence.size(1)

        outputs = self.transformer(sequence, use_kv_cache=use_kv_cache)

        # Each head slices the transformer outputs at its own token positions within each block.
        logits_latents = self.head_latents(outputs, num_steps, prev_steps)
        logits_rewards = self.head_rewards(outputs, num_steps, prev_steps)
        logits_ends = self.head_ends(outputs, num_steps, prev_steps)

        return {
            "output_sequence": outputs,
            "logits_latents": logits_latents,
            "logits_rewards": logits_rewards,
            "logits_ends": logits_ends,
        }
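

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original file). The sub-configs
# below list only the keys referenced in this module; FrameEncoder and
# TransformerEncoder require additional fields (see .models.convnet and
# .models.transformer), and every value here is hypothetical. The values are
# chosen so the embed_dim assertion in __init__ holds: (64 // 2**4) ** 2 * 16 == 256.
#
#   config = {
#       "image_size": 64,
#       "num_actions": 6,
#       "latent_vocab_size": 1024,
#       "two_hot_rews": False,
#       "frame_cnn_config": {"down": [1, 1, 1, 1], "latent_dim": 16, ...},
#       "transformer_config": {"embed_dim": 256, "tokens_per_block": 8, "max_blocks": 20, ...},
#   }
#   world_model = WorldModel(config)
#
#   # `sequence` is assumed to be a (batch, num_tokens, embed_dim) float tensor
#   # of already-embedded frame / action / latent tokens.
#   sequence = torch.randn(1, config["transformer_config"]["tokens_per_block"], 256)
#   out = world_model(sequence)
#   out["logits_latents"], out["logits_rewards"], out["logits_ends"]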