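"""Transformer-based world model.

Encodes frames with a CNN, embeds actions and latent tokens, and predicts
next latent tokens, rewards, and episode ends with separate heads.
"""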
from einops import rearrange
from einops.layers.torch import Rearrange
import torch
import torch.nn as nn

from .models.convnet import FrameEncoder
from .models.slicer import Head
from .models.transformer import TransformerEncoder


class WorldModel(nn.Module):
    def __init__(self, config: dict) -> None:
        super().__init__()
        self.config = config
        self.transformer = TransformerEncoder(config["transformer_config"])

        # The flattened frame encoding must have exactly the transformer embedding width.
        assert ((config["image_size"] // 2 ** sum(config["frame_cnn_config"]["down"])) ** 2) * config["frame_cnn_config"]["latent_dim"] == config["transformer_config"]["embed_dim"]

        # Frame encoder: CNN features flattened into a single token per frame.
        self.frame_cnn = nn.Sequential(
            FrameEncoder(config["frame_cnn_config"]),
            Rearrange('b t c h w -> b t 1 (h w c)'),
            nn.LayerNorm(config["transformer_config"]["embed_dim"]),
        )
        self.act_emb = nn.Embedding(config["num_actions"], config["transformer_config"]["embed_dim"])
        self.latents_emb = nn.Embedding(config["latent_vocab_size"], config["transformer_config"]["embed_dim"])

        # Block masks over the tokens_per_block positions of a block: act_pattern keeps only the
        # action-token position, act_and_latents_but_last_pattern keeps the action token and every
        # latent token except the last one.
        act_pattern = torch.zeros(config["transformer_config"]["tokens_per_block"])
        act_pattern[1] = 1
        act_and_latents_but_last_pattern = torch.zeros(config["transformer_config"]["tokens_per_block"])
        act_and_latents_but_last_pattern[1:-1] = 1

        self.head_latents = Head(
            max_blocks=config["transformer_config"]["max_blocks"],
            block_mask=act_and_latents_but_last_pattern,
            head_module=nn.Sequential(
                nn.Linear(config["transformer_config"]["embed_dim"], config["transformer_config"]["embed_dim"]),
                nn.ReLU(),
                nn.Linear(config["transformer_config"]["embed_dim"], config["latent_vocab_size"]),
            ),
        )
        self.head_rewards = Head(
            max_blocks=config["transformer_config"]["max_blocks"],
            block_mask=act_pattern,
            head_module=nn.Sequential(
                nn.Linear(config["transformer_config"]["embed_dim"], config["transformer_config"]["embed_dim"]),
                nn.ReLU(),
                # 255 bins for two-hot reward regression, otherwise a 3-class reward head.
                nn.Linear(config["transformer_config"]["embed_dim"], 255 if config["two_hot_rews"] else 3),
            ),
        )
        self.head_ends = Head(
            max_blocks=config["transformer_config"]["max_blocks"],
            block_mask=act_pattern,
            head_module=nn.Sequential(
                nn.Linear(config["transformer_config"]["embed_dim"], config["transformer_config"]["embed_dim"]),
                nn.ReLU(),
                nn.Linear(config["transformer_config"]["embed_dim"], 2),
            ),
        )

    def __repr__(self) -> str:
        return "world_model"

    def blocks_left_in_kv_cache(self):
        return self.transformer.num_blocks_left_in_kv_cache

    def reset_kv_cache(self):
        self.transformer.reset_kv_cache(n=1)

    def forward(self, sequence: torch.FloatTensor, use_kv_cache: bool = False) -> dict:
        prev_steps = self.transformer.keys_values.size if use_kv_cache else 0
        num_steps = sequence.size(1)
        outputs = self.transformer(sequence, use_kv_cache=use_kv_cache)

        logits_latents = self.head_latents(outputs, num_steps, prev_steps)
        logits_rewards = self.head_rewards(outputs, num_steps, prev_steps)
        logits_ends = self.head_ends(outputs, num_steps, prev_steps)

        return {
            "output_sequence": outputs,
            "logits_latents": logits_latents,
            "logits_rewards": logits_rewards,
            "logits_ends": logits_ends,
        }
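

# Minimal usage sketch (illustrative only, kept as comments): the config keys mirror the ones
# read in __init__, but the concrete values and the shape of `sequence` are assumptions about
# the surrounding training code, not a canonical setup.
#
#   config = {
#       "image_size": 64,
#       "num_actions": 6,
#       "latent_vocab_size": 1024,
#       "two_hot_rews": True,
#       "frame_cnn_config": {...},       # must satisfy the embed_dim assert in __init__
#       "transformer_config": {...},     # embed_dim, tokens_per_block, max_blocks, ...
#   }
#   world_model = WorldModel(config)
#   sequence = ...                       # assumed (batch, num_tokens, embed_dim) embedded tokens
#   outputs = world_model(sequence)
#   outputs["logits_latents"], outputs["logits_rewards"], outputs["logits_ends"]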