from typing import Any, Optional, Tuple

from einops import rearrange
import torch
import torch.nn as nn
import torch.nn.functional as F

from .models.kv_caching import KeysValues
from .models.slicer import Embedder, Head
from .models.transformer import Transformer


class WorldModel(nn.Module):
    """Transformer world model over blocks of observation and action tokens.

    The token stream is divided into blocks of ``config["tokens_per_block"]``
    tokens. Within each block the last token is an action token and every
    other token is an observation token (see the masks built in
    ``__init__``). Three heads read the transformer output:

    * observation logits over ``config["vocab_size"]`` classes,
    * reward logits over 3 classes (matching the ``{-1, 0, 1}`` sign labels
      produced by ``compute_labels_world_model``),
    * end logits over 2 classes.

    Config keys read here: ``vocab_size``, ``act_vocab_size``,
    ``tokens_per_block``, ``max_tokens``, ``max_blocks``, ``embed_dim``
    (``Transformer(config)`` may read more).
    """

    def __init__(self, config: dict) -> None:
        super().__init__()
        self.obs_vocab_size, self.act_vocab_size = config["vocab_size"], config["act_vocab_size"]
        self.config = config
        self.transformer = Transformer(config)

        # Output positions that get an observation prediction: every position
        # except the second-to-last one. Assuming next-token prediction (the
        # observation labels are shifted by one in compute_labels_world_model),
        # the output at the last observation token would predict the action
        # token, so it is excluded from the observation head.
        all_but_last_obs_tokens_pattern = torch.ones(config["tokens_per_block"])
        all_but_last_obs_tokens_pattern[-2] = 0

        # The last token of each block is the action token; the rest are
        # observation tokens.
        act_tokens_pattern = torch.zeros(self.config["tokens_per_block"])
        act_tokens_pattern[-1] = 1
        obs_tokens_pattern = 1 - act_tokens_pattern

        # NOTE: submodules are created in the same order as before so that
        # parameter initialisation (RNG consumption) and state_dict ordering
        # are unchanged.
        self.pos_emb = nn.Embedding(config["max_tokens"], config["embed_dim"])

        # The Embedder presumably selects the embedding table per position via
        # block_masks (actions -> first table, observations -> second); the
        # exact semantics live in .models.slicer — TODO confirm.
        self.embedder = Embedder(
            max_blocks=config["max_blocks"],
            block_masks=[act_tokens_pattern, obs_tokens_pattern],
            embedding_tables=nn.ModuleList([
                nn.Embedding(self.act_vocab_size, config["embed_dim"]),
                nn.Embedding(self.obs_vocab_size, config["embed_dim"]),
            ]),
        )

        self.head_observations = Head(
            max_blocks=config["max_blocks"],
            block_mask=all_but_last_obs_tokens_pattern,
            head_module=self._mlp(config["embed_dim"], self.obs_vocab_size),
        )

        # Reward and end heads read only the action-token positions.
        self.head_rewards = Head(
            max_blocks=config["max_blocks"],
            block_mask=act_tokens_pattern,
            head_module=self._mlp(config["embed_dim"], 3),
        )

        self.head_ends = Head(
            max_blocks=config["max_blocks"],
            block_mask=act_tokens_pattern,
            head_module=self._mlp(config["embed_dim"], 2),
        )

    @staticmethod
    def _mlp(embed_dim: int, out_dim: int) -> nn.Sequential:
        """Build the two-layer MLP (Linear -> ReLU -> Linear) used by every head."""
        return nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.ReLU(),
            nn.Linear(embed_dim, out_dim),
        )

    def __repr__(self) -> str:
        return "world_model"

    def forward(self, tokens: torch.LongTensor, past_keys_values: Optional[KeysValues] = None) -> dict:
        """Run the transformer on ``tokens`` and apply all three heads.

        Args:
            tokens: Token ids, shape ``(B, T)``.
            past_keys_values: Optional key/value cache. When given, its
                ``size`` is used as the number of already-processed positions,
                which offsets the positional embeddings (assumed to be the
                cached token count — see .models.kv_caching).

        Returns:
            Dict with ``"output_sequence"`` (transformer output) and
            ``"logits_observations"``, ``"logits_rewards"``, ``"logits_ends"``.

        Raises:
            AssertionError: if ``prev_steps + T`` exceeds ``config["max_tokens"]``.
        """
        num_steps = tokens.size(1)  # (B, T)
        prev_steps = 0 if past_keys_values is None else past_keys_values.size

        # pos_emb has max_tokens rows and is indexed up to prev_steps + num_steps - 1,
        # so the cached prefix plus the new tokens must fit in max_tokens.
        assert prev_steps + num_steps <= self.config["max_tokens"], (
            f"prev_steps ({prev_steps}) + num_steps ({num_steps}) exceeds "
            f"max_tokens ({self.config['max_tokens']})"
        )

        positions = prev_steps + torch.arange(num_steps, device=tokens.device)
        sequences = self.embedder(tokens, num_steps, prev_steps) + self.pos_emb(positions)

        x = self.transformer(sequences, past_keys_values)

        logits_observations = self.head_observations(x, num_steps=num_steps, prev_steps=prev_steps)
        logits_rewards = self.head_rewards(x, num_steps=num_steps, prev_steps=prev_steps)
        logits_ends = self.head_ends(x, num_steps=num_steps, prev_steps=prev_steps)

        return {
            "output_sequence": x,
            "logits_observations": logits_observations,
            "logits_rewards": logits_rewards,
            "logits_ends": logits_ends,
        }

    def generate_empty_keys_values(self, n: int = 1):
        """Return an empty KV cache for ``n`` sequences sized to ``config["max_tokens"]``.

        Delegates to ``self.transformer.generate_empty_keys_values``.
        """
        return self.transformer.generate_empty_keys_values(n=n, max_tokens=self.config["max_tokens"])

    def compute_labels_world_model(
        self,
        obs_tokens: torch.Tensor,
        rewards: torch.Tensor,
        ends: torch.Tensor,
        mask_padding: torch.BoolTensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Build flattened classification targets for the three heads.

        Positions where ``mask_padding`` is False are filled with -100, which
        is the default ``ignore_index`` of ``F.cross_entropy``.

        Args:
            obs_tokens: Observation token ids, shape ``(B, T, K)``.
            rewards: Per-step rewards; only their sign is used. Assumed
                ``(B, T)`` like ``mask_padding``.
            ends: Per-step end flags with at most one nonzero entry per row.
                Assumed ``(B, T)``.
            mask_padding: True for real (non-padded) steps, assumed ``(B, T)``.

        Returns:
            ``(labels_observations, labels_rewards, labels_ends)``, each
            flattened to 1-D. Observation labels are ``obs_tokens`` flattened
            over ``(t k)`` with the first token dropped (next-token targets).
            Reward labels are ``sign(reward) + 1`` in ``{0, 1, 2}``.

        Raises:
            AssertionError: if any row of ``ends`` sums to more than 1.
        """
        assert torch.all(ends.sum(dim=1) <= 1)  # at most 1 done
        mask_fill = torch.logical_not(mask_padding)

        masked_obs = obs_tokens.masked_fill(mask_fill.unsqueeze(-1).expand_as(obs_tokens), -100)
        # Drop the first token so labels line up with next-token predictions.
        labels_observations = rearrange(masked_obs, 'b t k -> b (t k)')[:, 1:]

        labels_rewards = (rewards.sign() + 1).masked_fill(mask_fill, -100).long()  # Rewards clipped to {-1, 0, 1}

        # Cast before filling: on a bool tensor -100 would become True, and
        # class-index targets must be integral. No-op for long inputs.
        labels_ends = ends.long().masked_fill(mask_fill, -100)

        return labels_observations.reshape(-1), labels_rewards.reshape(-1), labels_ends.reshape(-1)