File size: 4,201 Bytes
cf88ce4
 
 
 
 
 
 
484653f
 
 
cf88ce4
 
d1d30ce
cf88ce4
b7dc1ea
cf88ce4
 
 
 
 
 
 
 
 
 
 
 
 
 
cb06601
cf88ce4
 
 
 
 
 
 
 
cb06601
cf88ce4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ccf803e
a30bf83
ccf803e
7e213bf
8635508
cf88ce4
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
from typing import Any, Optional, Tuple

from einops import rearrange
import torch
import torch.nn as nn
import torch.nn.functional as F

from .models.kv_caching import KeysValues
from .models.slicer import Embedder, Head
from .models.transformer import Transformer

class WorldModel(nn.Module):
    """Transformer over interleaved observation/action tokens with three prediction heads.

    Tokens are organised in blocks of ``config["tokens_per_block"]`` positions:
    the last position of every block is the action token and all preceding
    positions are observation tokens (see the masks built in ``__init__``).

    The model exposes three heads on top of the transformer output:

    * ``head_observations`` -- logits over ``config["vocab_size"]`` observation tokens,
    * ``head_rewards`` -- 3 logits (matches the ``{0, 1, 2}`` reward classes
      produced by ``compute_labels_world_model``),
    * ``head_ends`` -- 2 logits.

    Config keys read directly by this class: ``vocab_size``, ``act_vocab_size``,
    ``tokens_per_block``, ``max_tokens``, ``embed_dim``, ``max_blocks``. The whole
    ``config`` dict is also forwarded to ``Transformer``, which may read more.
    """

    def __init__(self, config: dict) -> None:
        """Build the embeddings, the transformer and the three heads from ``config``."""
        super().__init__()
        self.obs_vocab_size, self.act_vocab_size = config["vocab_size"], config["act_vocab_size"]
        self.config = config
        self.transformer = Transformer(config)

        tokens_per_block = config["tokens_per_block"]
        embed_dim = config["embed_dim"]
        max_blocks = config["max_blocks"]

        # Per-block 0/1 masks selecting positions inside one block.
        # Every position except the second-to-last one, i.e. everything but the
        # last observation token (the action token sits at index -1).
        all_but_last_obs_tokens_pattern = torch.ones(tokens_per_block)
        all_but_last_obs_tokens_pattern[-2] = 0
        # Only the last position of the block: the action token.
        act_tokens_pattern = torch.zeros(tokens_per_block)
        act_tokens_pattern[-1] = 1
        # Complement of the action mask: all observation tokens.
        obs_tokens_pattern = 1 - act_tokens_pattern

        # One learned positional embedding per absolute position in [0, max_tokens).
        self.pos_emb = nn.Embedding(config["max_tokens"], embed_dim)

        # The i-th mask is paired with the i-th embedding table (action first,
        # then observation). NOTE(review): the exact routing is implemented in
        # models.slicer.Embedder -- confirm there.
        self.embedder = Embedder(
            max_blocks=max_blocks,
            block_masks=[act_tokens_pattern, obs_tokens_pattern],
            embedding_tables=nn.ModuleList([
                nn.Embedding(self.act_vocab_size, embed_dim),
                nn.Embedding(self.obs_vocab_size, embed_dim),
            ]),
        )

        # NOTE(review): Head presumably applies ``head_module`` only at the
        # positions selected by ``block_mask`` -- confirm in models.slicer.Head.
        self.head_observations = Head(
            max_blocks=max_blocks,
            block_mask=all_but_last_obs_tokens_pattern,
            head_module=self._make_head_module(embed_dim, self.obs_vocab_size),
        )

        self.head_rewards = Head(
            max_blocks=max_blocks,
            block_mask=act_tokens_pattern,
            head_module=self._make_head_module(embed_dim, 3),
        )

        self.head_ends = Head(
            max_blocks=max_blocks,
            block_mask=act_tokens_pattern,
            head_module=self._make_head_module(embed_dim, 2),
        )

    @staticmethod
    def _make_head_module(embed_dim: int, num_outputs: int) -> nn.Sequential:
        """Return the Linear -> ReLU -> Linear MLP used by every prediction head."""
        return nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.ReLU(),
            nn.Linear(embed_dim, num_outputs),
        )

    def __repr__(self) -> str:
        return "world_model"

    def forward(self, tokens: torch.LongTensor, past_keys_values: Optional[KeysValues] = None) -> dict:
        """Run the transformer on ``tokens`` and apply the three heads.

        Args:
            tokens: Token ids of shape (B, T).
            past_keys_values: Optional key/value cache from previous calls. Its
                ``size`` is used as the number of tokens already processed, so
                positions of ``tokens`` start at that offset.

        Returns:
            A dict with keys ``"output_sequence"`` (transformer output),
            ``"logits_observations"``, ``"logits_rewards"`` and ``"logits_ends"``.

        Raises:
            AssertionError: If ``T`` alone, or ``T`` plus the number of cached
                tokens, exceeds ``config["max_tokens"]``.
        """
        num_steps = tokens.size(1)  # (B, T)
        max_tokens = self.config["max_tokens"]
        assert num_steps <= max_tokens
        prev_steps = 0 if past_keys_values is None else past_keys_values.size
        # The positional table only has ``max_tokens`` rows; fail with a clear
        # message instead of an out-of-range embedding lookup.
        assert prev_steps + num_steps <= max_tokens, (
            f"position overflow: {prev_steps} cached + {num_steps} new tokens "
            f"exceeds max_tokens={max_tokens}"
        )

        # Absolute positions continue from where the cache left off.
        positions = prev_steps + torch.arange(num_steps, device=tokens.device)
        sequences = self.embedder(tokens, num_steps, prev_steps) + self.pos_emb(positions)

        x = self.transformer(sequences, past_keys_values)

        logits_observations = self.head_observations(x, num_steps=num_steps, prev_steps=prev_steps)
        logits_rewards = self.head_rewards(x, num_steps=num_steps, prev_steps=prev_steps)
        logits_ends = self.head_ends(x, num_steps=num_steps, prev_steps=prev_steps)
        return {
            "output_sequence": x,
            "logits_observations": logits_observations,
            "logits_rewards": logits_rewards,
            "logits_ends": logits_ends,
        }

    def generate_empty_keys_values(self, n: int = 1):
        """Return an empty key/value cache created by the transformer.

        Delegates to ``self.transformer.generate_empty_keys_values`` with
        ``max_tokens=config["max_tokens"]``.

        Args:
            n: Forwarded as ``n``; presumably the batch size of the cache --
                confirm against ``Transformer.generate_empty_keys_values``.
        """
        return self.transformer.generate_empty_keys_values(n=n, max_tokens=self.config["max_tokens"])

    def compute_labels_world_model(self, obs_tokens: torch.Tensor, rewards: torch.Tensor, ends: torch.Tensor, mask_padding: torch.BoolTensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Build flattened training labels for the three heads.

        Positions where ``mask_padding`` is False are filled with ``-100``.
        NOTE(review): -100 is the default ``ignore_index`` of
        ``torch.nn.functional.cross_entropy``; presumably these labels feed such
        a loss -- confirm at the call site.

        Args:
            obs_tokens: Observation token ids of shape (b, t, k).
            rewards: Rewards, maskable by ``mask_padding``.
            ends: Episode-end flags, maskable by ``mask_padding``; each row may
                contain at most one non-zero entry.
            mask_padding: Boolean mask, True at valid (non-padded) steps.

        Returns:
            A 3-tuple of 1-D tensors ``(labels_observations, labels_rewards,
            labels_ends)``.

        Raises:
            AssertionError: If any row of ``ends`` sums to more than 1.
        """
        assert torch.all(ends.sum(dim=1) <= 1)  # at most 1 done
        mask_fill = torch.logical_not(mask_padding)
        # Flatten (t, k) into one token axis, then drop the very first token of
        # each sequence.
        labels_observations = rearrange(obs_tokens.masked_fill(mask_fill.unsqueeze(-1).expand_as(obs_tokens), -100), 'b t k -> b (t k)')[:, 1:]
        # sign() maps rewards to {-1, 0, 1}; +1 shifts them to class ids {0, 1, 2}.
        labels_rewards = (rewards.sign() + 1).masked_fill(mask_fill, -100).long()  # Rewards clipped to {-1, 0, 1}
        labels_ends = ends.masked_fill(mask_fill, -100)
        return labels_observations.reshape(-1), labels_rewards.reshape(-1), labels_ends.reshape(-1)