from typing import Any, Optional, Tuple

from einops import rearrange
import torch
import torch.nn as nn
import torch.nn.functional as F

from .models.kv_caching import KeysValues
from .models.slicer import Embedder, Head
from .models.transformer import Transformer


class WorldModel(nn.Module):

    def __init__(self, config: dict) -> None:
        super().__init__()
        self.obs_vocab_size, self.act_vocab_size = config["vocab_size"], config["act_vocab_size"]
        self.config = config
        self.transformer = Transformer(config)

        # A block is tokens_per_block tokens: the observation tokens of a frame followed by one action token.
        # These masks select which positions within a block each embedding table / head applies to.
        all_but_last_obs_tokens_pattern = torch.ones(config["tokens_per_block"])
        all_but_last_obs_tokens_pattern[-2] = 0  # every position except the last observation token
        act_tokens_pattern = torch.zeros(self.config["tokens_per_block"])
        act_tokens_pattern[-1] = 1  # the action token is the last token of each block
        obs_tokens_pattern = 1 - act_tokens_pattern

        self.pos_emb = nn.Embedding(config["max_tokens"], config["embed_dim"])

        self.embedder = Embedder(
            max_blocks=config["max_blocks"],
            block_masks=[act_tokens_pattern, obs_tokens_pattern],
            embedding_tables=nn.ModuleList([
                nn.Embedding(self.act_vocab_size, config["embed_dim"]),
                nn.Embedding(self.obs_vocab_size, config["embed_dim"]),
            ])
        )

        self.head_observations = Head(
            max_blocks=config["max_blocks"],
            block_mask=all_but_last_obs_tokens_pattern,
            head_module=nn.Sequential(
                nn.Linear(config["embed_dim"], config["embed_dim"]),
                nn.ReLU(),
                nn.Linear(config["embed_dim"], self.obs_vocab_size)
            )
        )

        self.head_rewards = Head(
            max_blocks=config["max_blocks"],
            block_mask=act_tokens_pattern,
            head_module=nn.Sequential(
                nn.Linear(config["embed_dim"], config["embed_dim"]),
                nn.ReLU(),
                nn.Linear(config["embed_dim"], 3)  # three classes: reward sign in {-1, 0, +1}
            )
        )

        self.head_ends = Head(
            max_blocks=config["max_blocks"],
            block_mask=act_tokens_pattern,
            head_module=nn.Sequential(
                nn.Linear(config["embed_dim"], config["embed_dim"]),
                nn.ReLU(),
                nn.Linear(config["embed_dim"], 2)  # binary episode-termination logits
            )
        )

    def __repr__(self) -> str:
        return "world_model"

    def forward(self, tokens: torch.LongTensor, past_keys_values: Optional[KeysValues] = None) -> dict:
        num_steps = tokens.size(1)
        assert num_steps <= self.config["max_tokens"]
        prev_steps = 0 if past_keys_values is None else past_keys_values.size

        sequences = self.embedder(tokens, num_steps, prev_steps) + self.pos_emb(prev_steps + torch.arange(num_steps, device=tokens.device))

        x = self.transformer(sequences, past_keys_values)

        logits_observations = self.head_observations(x, num_steps=num_steps, prev_steps=prev_steps)
        logits_rewards = self.head_rewards(x, num_steps=num_steps, prev_steps=prev_steps)
        logits_ends = self.head_ends(x, num_steps=num_steps, prev_steps=prev_steps)

        return {
            "output_sequence": x,
            "logits_observations": logits_observations,
            "logits_rewards": logits_rewards,
            "logits_ends": logits_ends,
        }

    def generate_empty_keys_values(self, n: int = 1) -> KeysValues:
        return self.transformer.generate_empty_keys_values(n=n, max_tokens=self.config["max_tokens"])

    def compute_labels_world_model(self, obs_tokens: torch.Tensor, rewards: torch.Tensor, ends: torch.Tensor, mask_padding: torch.BoolTensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        assert torch.all(ends.sum(dim=1) <= 1)  # at most one end-of-episode flag per sequence
        mask_fill = torch.logical_not(mask_padding)
        # Padded positions are filled with -100 so the cross-entropy loss ignores them.
        labels_observations = rearrange(obs_tokens.masked_fill(mask_fill.unsqueeze(-1).expand_as(obs_tokens), -100), 'b t k -> b (t k)')[:, 1:]
        labels_rewards = (rewards.sign() + 1).masked_fill(mask_fill, -100).long()  # reward sign {-1, 0, +1} -> class index {0, 1, 2}
        labels_ends = ends.masked_fill(mask_fill, -100)
        return labels_observations.reshape(-1), labels_rewards.reshape(-1), labels_ends.reshape(-1)
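

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, kept as comments). The config keys
# below are exactly the ones read in __init__, but the concrete values, the
# vocabulary sizes, and any additional keys required by Transformer(config)
# are assumptions.
#
#   config = {
#       "vocab_size": 512,        # observation token vocabulary size
#       "act_vocab_size": 6,      # action token vocabulary size
#       "embed_dim": 256,
#       "tokens_per_block": 17,   # e.g. 16 observation tokens + 1 action token
#       "max_blocks": 20,
#       "max_tokens": 17 * 20,
#       # ... plus whatever keys Transformer(config) itself requires
#   }
#   world_model = WorldModel(config)
#
#   # One full block: observation token ids followed by a single action id.
#   obs = torch.randint(0, config["vocab_size"], (4, 16))
#   act = torch.randint(0, config["act_vocab_size"], (4, 1))
#   tokens = torch.cat([obs, act], dim=1)              # (batch, num_steps)
#
#   outputs = world_model(tokens)                      # dict with "logits_*" entries
#   kv = world_model.generate_empty_keys_values(n=4)   # empty cache for step-by-step rollouts
#   step = world_model(tokens[:, :1], past_keys_values=kv)
# ---------------------------------------------------------------------------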