from typing import Any, Optional, Tuple
from einops import rearrange
import torch
import torch.nn as nn
import torch.nn.functional as F
from .models.kv_caching import KeysValues
from .models.slicer import Embedder, Head
from .models.transformer import Transformer
class WorldModel(nn.Module):
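    """
    Autoregressive transformer world model over interleaved observation/action tokens.

    Each block of the sequence holds the observation tokens of one frame followed by a
    single action token. Separate heads predict the next observation token, a 3-way
    reward class, and episode termination.
    """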
def __init__(self, config: dict) -> None:
super().__init__()
self.obs_vocab_size, self.act_vocab_size = config["vocab_size"], config["act_vocab_size"]
self.config = config
self.transformer = Transformer(config)
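        # Token layout per block: the observation tokens of one frame followed by one action token.
        # This mask drops the last observation position (index -2), whose next-token target would
        # be the action token rather than an observation token.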
all_but_last_obs_tokens_pattern = torch.ones(config["tokens_per_block"])
all_but_last_obs_tokens_pattern[-2] = 0
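        # The action token sits at the last position of each block; observation tokens fill the rest.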
act_tokens_pattern = torch.zeros(self.config["tokens_per_block"])
act_tokens_pattern[-1] = 1
obs_tokens_pattern = 1 - act_tokens_pattern
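        # Learned absolute positional embedding over the full token sequence.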
self.pos_emb = nn.Embedding(config["max_tokens"], config["embed_dim"])
self.embedder = Embedder(
max_blocks=config["max_blocks"],
block_masks=[act_tokens_pattern, obs_tokens_pattern],
embedding_tables=nn.ModuleList([nn.Embedding(self.act_vocab_size, config["embed_dim"]), nn.Embedding(self.obs_vocab_size, config["embed_dim"])])
)
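        # Observation head: logits over the observation vocabulary, evaluated at the positions
        # kept by all_but_last_obs_tokens_pattern.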
self.head_observations = Head(
max_blocks=config["max_blocks"],
block_mask=all_but_last_obs_tokens_pattern,
head_module=nn.Sequential(
nn.Linear(config["embed_dim"], config["embed_dim"]),
nn.ReLU(),
nn.Linear(config["embed_dim"], self.obs_vocab_size)
)
)
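        # Reward head: 3-way classification over the sign of the clipped reward {-1, 0, +1},
        # evaluated at action-token positions.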
self.head_rewards = Head(
max_blocks=config["max_blocks"],
block_mask=act_tokens_pattern,
head_module=nn.Sequential(
nn.Linear(config["embed_dim"], config["embed_dim"]),
nn.ReLU(),
nn.Linear(config["embed_dim"], 3)
)
)
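        # Termination head: binary classification (continue / episode end) at action-token positions.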
self.head_ends = Head(
max_blocks=config["max_blocks"],
block_mask=act_tokens_pattern,
head_module=nn.Sequential(
nn.Linear(config["embed_dim"], config["embed_dim"]),
nn.ReLU(),
nn.Linear(config["embed_dim"], 2)
)
)
def __repr__(self) -> str:
return "world_model"
def forward(self, tokens: torch.LongTensor, past_keys_values: Optional[KeysValues] = None) -> dict:
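        # past_keys_values holds the transformer's KV cache; when provided, only the new tokens
        # are fed and positions are offset by the number of cached steps.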
num_steps = tokens.size(1) # (B, T)
assert num_steps <= self.config["max_tokens"]
prev_steps = 0 if past_keys_values is None else past_keys_values.size
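        # Offset the positional indices by the number of tokens already in the cache.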
sequences = self.embedder(tokens, num_steps, prev_steps) + self.pos_emb(prev_steps + torch.arange(num_steps, device=tokens.device))
x = self.transformer(sequences, past_keys_values)
logits_observations = self.head_observations(x, num_steps=num_steps, prev_steps=prev_steps)
logits_rewards = self.head_rewards(x, num_steps=num_steps, prev_steps=prev_steps)
logits_ends = self.head_ends(x, num_steps=num_steps, prev_steps=prev_steps)
return {
"output_sequence": x,
"logits_observations": logits_observations,
"logits_rewards": logits_rewards,
"logits_ends": logits_ends
}
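    # Allocates an empty KV cache for a batch of n sequences, sized for up to max_tokens steps.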
    def generate_empty_keys_values(self, n: int = 1) -> KeysValues:
        return self.transformer.generate_empty_keys_values(n=n, max_tokens=self.config["max_tokens"])
    def compute_labels_world_model(self, obs_tokens: torch.Tensor, rewards: torch.Tensor, ends: torch.Tensor, mask_padding: torch.BoolTensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        assert torch.all(ends.sum(dim=1) <= 1)  # at most one episode termination per sequence
mask_fill = torch.logical_not(mask_padding)
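        # Padded positions get label -100, the ignore_index used by cross-entropy losses.
        # Observation labels are the obs tokens flattened over time and shifted by one,
        # so each position is trained to predict the following token.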
labels_observations = rearrange(obs_tokens.masked_fill(mask_fill.unsqueeze(-1).expand_as(obs_tokens), -100), 'b t k -> b (t k)')[:, 1:]
        labels_rewards = (rewards.sign() + 1).masked_fill(mask_fill, -100).long()  # sign() clips rewards to {-1, 0, 1}; +1 maps them to class indices {0, 1, 2}
labels_ends = ends.masked_fill(mask_fill, -100)
return labels_observations.reshape(-1), labels_rewards.reshape(-1), labels_ends.reshape(-1)