from einops import rearrange
from einops.layers.torch import Rearrange
import torch
import torch.nn as nn

from .models.convnet import FrameEncoder
from .models.slicer import Head
from .models.transformer import TransformerEncoder

class WorldModel(nn.Module):
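    """Transformer-based world model: each frame is encoded into a single token,
    interleaved with action and latent tokens, and the transformer outputs feed
    three heads predicting the next latent tokens, the reward, and episode end."""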
    def __init__(self, config: dict) -> None:
        super().__init__()
        self.config = config
        self.transformer = TransformerEncoder(config["transformer_config"])

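        # The frame encoder's flattened output (h * w * c after downsampling) must match the
        # transformer embedding dimension, since each frame is collapsed into a single token.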
        assert ((config["image_size"] // 2 ** sum(config["frame_cnn_config"]["down"])) ** 2) * config["frame_cnn_config"]["latent_dim"] == config["transformer_config"]["embed_dim"]
        self.frame_cnn = nn.Sequential(FrameEncoder(config["frame_cnn_config"]), Rearrange('b t c h w -> b t 1 (h w c)'), nn.LayerNorm(config["transformer_config"]["embed_dim"]))

        self.act_emb = nn.Embedding(config["num_actions"], config["transformer_config"]["embed_dim"])
        self.latents_emb = nn.Embedding(config["latent_vocab_size"], config["transformer_config"]["embed_dim"])

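        # Block masks select which positions inside each block feed a given head:
        # act_pattern keeps only the action token (index 1); act_and_latents_but_last_pattern
        # keeps the action token and every latent token except the last one.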
        act_pattern = torch.zeros(config["transformer_config"]["tokens_per_block"])
        act_pattern[1] = 1
        act_and_latents_but_last_pattern = torch.zeros(config["transformer_config"]["tokens_per_block"]) 
        act_and_latents_but_last_pattern[1:-1] = 1

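        # Latent head: predicts the next latent tokens, reading from the action token
        # and all latent tokens but the last within each block.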
        self.head_latents = Head(
            max_blocks=config["transformer_config"]["max_blocks"],
            block_mask=act_and_latents_but_last_pattern,
            head_module=nn.Sequential(
                nn.Linear(config["transformer_config"]["embed_dim"], config["transformer_config"]["embed_dim"]), nn.ReLU(),
                nn.Linear(config["transformer_config"]["embed_dim"], config["latent_vocab_size"])
            )
        )

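        # Reward head: 255 bins when two-hot reward encoding is enabled, otherwise
        # 3 classes (presumably negative / zero / positive reward).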
        self.head_rewards = Head(
            max_blocks=config["transformer_config"]["max_blocks"],
            block_mask=act_pattern,
            head_module=nn.Sequential(
                nn.Linear(config["transformer_config"]["embed_dim"], config["transformer_config"]["embed_dim"]), nn.ReLU(),
                nn.Linear(config["transformer_config"]["embed_dim"], 255 if config["two_hot_rews"] else 3)
            )
        )

        self.head_ends = Head(
            max_blocks=config["transformer_config"]["max_blocks"],
            block_mask=act_pattern,
            head_module=nn.Sequential(
                nn.Linear(config["transformer_config"]["embed_dim"], config["transformer_config"]["embed_dim"]), nn.ReLU(),
                nn.Linear(config["transformer_config"]["embed_dim"], 2)
            )
        )

    def __repr__(self) -> str:
        return "world_model"

    def blocks_left_in_kv_cache(self):
        return self.transformer.num_blocks_left_in_kv_cache

    def reset_kv_cache(self):
        self.transformer.reset_kv_cache(n=1)

    def forward(self, sequence: torch.FloatTensor, use_kv_cache: bool = False) -> dict:      
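        # With the KV cache active, prev_steps tells the heads how many tokens were already
        # processed, so they slice the correct positions within each block for the new tokens.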
        prev_steps = self.transformer.keys_values.size if use_kv_cache else 0
        num_steps = sequence.size(1)

        outputs = self.transformer(sequence, use_kv_cache=use_kv_cache)

        logits_latents = self.head_latents(outputs, num_steps, prev_steps)
        logits_rewards = self.head_rewards(outputs, num_steps, prev_steps)
        logits_ends = self.head_ends(outputs, num_steps, prev_steps)

        return {
            "output_sequence": outputs, 
            "logits_latents": logits_latents,
            "logits_rewards": logits_rewards,
            "logits_ends": logits_ends
        }
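
# Illustrative usage sketch (values and the elided sub-config fields are assumptions;
# the exact schemas expected by TransformerEncoder and FrameEncoder live in their modules):
#
#   config = {
#       "image_size": 64,
#       "num_actions": 18,
#       "latent_vocab_size": 512,
#       "two_hot_rews": True,
#       "frame_cnn_config": {..., "down": [1, 1, 1, 1], "latent_dim": 16},
#       "transformer_config": {..., "embed_dim": 256, "tokens_per_block": 6, "max_blocks": 20},
#   }
#   # (64 // 2**4)**2 * 16 == 256 == embed_dim, satisfying the constructor's assert.
#   wm = WorldModel(config)
#   outputs = wm(token_sequence)        # token_sequence: (batch, num_tokens, embed_dim) floats
#   outputs["logits_latents"]           # per-position logits over the latent vocabulary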