| from typing import Any, Optional, Tuple |
|
|
| from einops import rearrange |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| from .models.kv_caching import KeysValues |
| from .models.slicer import Embedder, Head |
| from .models.transformer import Transformer |
|
|
class WorldModel(nn.Module):
    """Transformer world model over fixed-size blocks of observation and action tokens.

    Each block holds ``config["tokens_per_block"]`` tokens. The last position is the
    action token and every other position is an observation token (see the masks
    built in ``__init__``). Three MLP heads read the transformer output:
    observation-token logits, 3-class reward logits and 2-class end logits.

    Config keys read here: ``vocab_size``, ``act_vocab_size``, ``tokens_per_block``,
    ``max_tokens``, ``max_blocks`` and ``embed_dim``. The whole dict is also passed to
    ``Transformer``.
    """

    def __init__(self, config: dict) -> None:
        super().__init__()
        self.obs_vocab_size, self.act_vocab_size = config["vocab_size"], config["act_vocab_size"]
        self.config = config
        self.transformer = Transformer(config)

        tokens_per_block = config["tokens_per_block"]
        embed_dim = config["embed_dim"]
        max_blocks = config["max_blocks"]

        # 1 at every position except index -2, the last observation token of a block.
        # NOTE(review): presumably these are the positions whose output predicts the
        # next observation token; confirm against Head and the training loss.
        all_but_last_obs_tokens_pattern = torch.ones(tokens_per_block)
        all_but_last_obs_tokens_pattern[-2] = 0
        # The action token is the last slot of a block; all other slots are observations.
        act_tokens_pattern = torch.zeros(tokens_per_block)
        act_tokens_pattern[-1] = 1
        obs_tokens_pattern = 1 - act_tokens_pattern

        # Learned positional embedding with one row per absolute token position.
        self.pos_emb = nn.Embedding(config["max_tokens"], embed_dim)

        # NOTE(review): presumably Embedder pairs block_masks[i] with
        # embedding_tables[i] (action mask -> action table, obs mask -> obs table).
        self.embedder = Embedder(
            max_blocks=max_blocks,
            block_masks=[act_tokens_pattern, obs_tokens_pattern],
            embedding_tables=nn.ModuleList([
                nn.Embedding(self.act_vocab_size, embed_dim),
                nn.Embedding(self.obs_vocab_size, embed_dim),
            ]),
        )

        self.head_observations = Head(
            max_blocks=max_blocks,
            block_mask=all_but_last_obs_tokens_pattern,
            head_module=self._mlp_head(embed_dim, self.obs_vocab_size),
        )
        # 3 reward classes, matching the {0, 1, 2} labels built by compute_labels_world_model.
        self.head_rewards = Head(
            max_blocks=max_blocks,
            block_mask=act_tokens_pattern,
            head_module=self._mlp_head(embed_dim, 3),
        )
        # 2 classes for the `ends` flag, which compute_labels_world_model casts to int64.
        self.head_ends = Head(
            max_blocks=max_blocks,
            block_mask=act_tokens_pattern,
            head_module=self._mlp_head(embed_dim, 2),
        )

    @staticmethod
    def _mlp_head(embed_dim: int, out_dim: int) -> nn.Sequential:
        """Build the Linear -> ReLU -> Linear module shared by all three heads."""
        return nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.ReLU(),
            nn.Linear(embed_dim, out_dim),
        )

    def __repr__(self) -> str:
        return "world_model"

    def forward(self, tokens: torch.LongTensor, past_keys_values: Optional[KeysValues] = None) -> dict:
        """Embed ``tokens``, run the transformer and apply the three heads.

        Args:
            tokens: token ids whose dimension 1 is the sequence dimension
                (presumably shaped (batch, num_steps); confirm with callers).
            past_keys_values: optional KV cache. Its ``size`` is treated as the number
                of positions already processed, so the new tokens are embedded at
                positions ``size .. size + num_steps - 1``.

        Returns:
            A dict with keys ``"output_sequence"`` (transformer output),
            ``"logits_observations"``, ``"logits_rewards"`` and ``"logits_ends"``.

        Raises:
            AssertionError: if the cached prefix plus the new tokens exceed
                ``config["max_tokens"]``.
        """
        num_steps = tokens.size(1)
        prev_steps = 0 if past_keys_values is None else past_keys_values.size
        max_tokens = self.config["max_tokens"]
        # pos_emb has exactly max_tokens rows, so the cached prefix and the new tokens
        # must fit together. Checking num_steps alone let a long cache fail opaquely
        # inside the embedding lookup.
        assert prev_steps + num_steps <= max_tokens, (
            f"{prev_steps} cached + {num_steps} new tokens exceed max_tokens={max_tokens}"
        )

        positions = prev_steps + torch.arange(num_steps, device=tokens.device)
        sequences = self.embedder(tokens, num_steps, prev_steps) + self.pos_emb(positions)

        x = self.transformer(sequences, past_keys_values)

        # Each head receives the step offsets so it can align its block mask with x.
        return {
            "output_sequence": x,
            "logits_observations": self.head_observations(x, num_steps=num_steps, prev_steps=prev_steps),
            "logits_rewards": self.head_rewards(x, num_steps=num_steps, prev_steps=prev_steps),
            "logits_ends": self.head_ends(x, num_steps=num_steps, prev_steps=prev_steps),
        }

    def generate_empty_keys_values(self, n=1):
        """Return the transformer's empty KV cache for ``n`` sequences of up to ``max_tokens``."""
        return self.transformer.generate_empty_keys_values(n=n, max_tokens=self.config["max_tokens"])

    def compute_labels_world_model(self, obs_tokens: torch.Tensor, rewards: torch.Tensor, ends: torch.Tensor, mask_padding: torch.BoolTensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Build flattened training targets for the observation, reward and end heads.

        Steps where ``mask_padding`` is False are filled with -100. Presumably this lets
        a loss such as ``F.cross_entropy`` skip them, since its default
        ``ignore_index`` is -100; confirm against the loss code.

        Args:
            obs_tokens: observation token ids shaped (b, t, k).
            rewards: per-step rewards (presumably (b, t)). Only their sign is used.
            ends: per-step end flags (presumably (b, t)). At most one flag may be set
                per sequence along dim 1.
            mask_padding: mask broadcastable to (b, t), True for real (non-padded) steps.

        Returns:
            ``(labels_observations, labels_rewards, labels_ends)``, each flattened to 1-D:
            - ``labels_observations``: observation tokens flattened per sequence, with
              the first token dropped.
            - ``labels_rewards``: reward signs shifted to the classes {0, 1, 2}.
            - ``labels_ends``: end flags as int64.
        """
        assert torch.all(ends.sum(dim=1) <= 1)
        mask_fill = torch.logical_not(mask_padding)

        # Drop the first token: it is presumably never a prediction target (next-token shift).
        labels_observations = rearrange(
            obs_tokens.masked_fill(mask_fill.unsqueeze(-1).expand_as(obs_tokens), -100),
            'b t k -> b (t k)',
        )[:, 1:]
        # sign() maps rewards to {-1, 0, 1}; +1 gives the head's class ids {0, 1, 2}.
        labels_rewards = (rewards.sign() + 1).masked_fill(mask_fill, -100).long()
        # Cast before filling: on a bool tensor masked_fill(-100) would store True (1),
        # and float flags are not valid class indices. This is a no-op for int64 input.
        labels_ends = ends.long().masked_fill(mask_fill, -100)
        return labels_observations.reshape(-1), labels_rewards.reshape(-1), labels_ends.reshape(-1)
|
|