import math

from einops import rearrange
import torch
import torch.nn as nn

from .models.convnet import FrameEncoder, FrameDecoder
from .models.quantizer import Quantizer


class Tokenizer(nn.Module):
    """Encodes a frame transition (x1, a, x2) into discrete tokens and decodes them back into a frame."""

    def __init__(self, config: dict) -> None:
        super().__init__()
        self.config = config
        # Spatial resolution of the encoder's latent map after downsampling.
        self.latent_res = config["image_size"] // 2 ** sum(config["encoder_config"]["down"])
        # Tokens lie on a square grid; each token covers a token_res x token_res patch of the latent map.
        self.tokens_grid_res = int(math.sqrt(config["num_tokens"]))
        self.token_res = self.latent_res // self.tokens_grid_res
        # Action embeddings: one broadcast as an extra image-sized plane for the encoder,
        # one reshaped into feature maps at latent resolution for the decoder.
        self.encoder_act_emb = nn.Embedding(config["num_actions"], config["image_size"] ** 2)
        self.decoder_act_emb = nn.Embedding(config["num_actions"], config["decoder_act_channels"] * self.latent_res ** 2)
        self.quantizer = Quantizer(
            config["codebook_size"],
            config["codebook_dim"],
            input_dim=config["encoder_config"]["latent_dim"] * self.token_res ** 2,
            max_codebook_updates_with_revival=config["max_codebook_updates_with_revival"],
        )
        self.encoder = FrameEncoder(config["encoder_config"])
        self.decoder = FrameDecoder(config["decoder_config"])
        self.frame_cnn = FrameEncoder(config["frame_cnn_config"])

    def __repr__(self) -> str:
        return "tokenizer"

    def forward(self, x1: torch.FloatTensor, a: torch.LongTensor, x2: torch.FloatTensor) -> dict:
        z = self.encode(x1, a, x2)
        # Split the latent map into a tokens_grid_res x tokens_grid_res grid of patches and
        # flatten each (token_res x token_res x channels) patch into one quantizer input vector.
        z = rearrange(z, 'b t c (h k) (w l) -> b t (h w) (k l c)', h=self.tokens_grid_res, w=self.tokens_grid_res)
        return self.quantizer(z)

    def encode(self, x1: torch.FloatTensor, a: torch.LongTensor, x2: torch.FloatTensor) -> torch.FloatTensor:
        # Broadcast the action embedding as an extra image-sized channel and encode the transition (x1, a, x2).
        a_emb = rearrange(self.encoder_act_emb(a), 'b t (h w) -> b t 1 h w', h=x1.size(3))
        encoder_input = torch.cat((x1, a_emb, x2), dim=2)
        z = self.encoder(encoder_input)
        return z

    def decode(self, x1: torch.FloatTensor, a: torch.LongTensor, q2: torch.FloatTensor, should_clamp: bool = False) -> torch.FloatTensor:
        # Reconstruct the next frame from previous-frame features, the action, and the quantized latents q2.
        x1_emb = self.frame_cnn(x1)
        a_emb = rearrange(self.decoder_act_emb(a), 'b t (c h w) -> b t c h w', c=self.config["decoder_act_channels"], h=x1_emb.size(3))
        decoder_input = torch.cat((x1_emb, a_emb, q2), dim=2)
        r = self.decoder(decoder_input)
        # Optionally snap the reconstruction to valid 8-bit pixel values rescaled to [0, 1].
        r = torch.clamp(r, 0, 1).mul(255).round().div(255) if should_clamp else r
        return r

    @torch.no_grad()
    def encode_decode(self, x1: torch.FloatTensor, a: torch.LongTensor, x2: torch.FloatTensor) -> torch.Tensor:
        z = self.encode(x1, a, x2)
        z = rearrange(z, 'b t c (h k) (w l) -> b t (h w) (k l c)', k=self.token_res, l=self.token_res)
        # Quantize, then fold the quantized token vectors back into a spatial feature map for the decoder.
        q = rearrange(self.quantizer(z).q, 'b t (h w) (k l e) -> b t e (h k) (w l)', h=self.tokens_grid_res, k=self.token_res, l=self.token_res)
        r = self.decode(x1, a, q, should_clamp=True)
        return r

    def embed_tokens(self, tokens: torch.LongTensor) -> torch.FloatTensor:
        # Look up token embeddings and fold them back into a spatial feature map,
        # inverting the patch flattening performed in forward.
        q = self.quantizer.embed_tokens(tokens)
        return rearrange(q, 'b t (h w) (k l e) -> b t e (h k) (w l)',
                         h=self.tokens_grid_res, w=self.tokens_grid_res, k=self.token_res, l=self.token_res)

    @torch.no_grad()
    def burn_in(self, obs: torch.FloatTensor, act: torch.LongTensor) -> torch.LongTensor:
        # Tokenize a sequence of observed transitions; obs has one more timestep than act.
        assert obs.size(1) == act.size(1) + 1
        quantizer_output = self(obs[:, :-1], act, obs[:, 1:])
        return quantizer_output.tokens
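

# ---------------------------------------------------------------------------
# Shape sanity sketch (illustrative only, not part of the model). The numbers
# below (image_size=64, sum(down)=2, latent_dim=32, num_tokens=16) are
# placeholder assumptions, not the project's actual configuration. It checks
# that the patch-flattening rearrange used in `forward` and the inverse
# rearrange used in `encode_decode`/`embed_tokens` round-trip exactly between
# the (b t c H W) latent map and the (b t num_tokens token_dim) token layout.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    image_size, down_total, latent_dim, num_tokens = 64, 2, 32, 16
    latent_res = image_size // 2 ** down_total            # 16
    tokens_grid_res = int(math.sqrt(num_tokens))           # 4
    token_res = latent_res // tokens_grid_res              # 4

    z = torch.randn(2, 3, latent_dim, latent_res, latent_res)  # (b t c H W)
    tokens_in = rearrange(z, 'b t c (h k) (w l) -> b t (h w) (k l c)',
                          h=tokens_grid_res, w=tokens_grid_res)
    assert tokens_in.shape == (2, 3, num_tokens, latent_dim * token_res ** 2)

    # Inverse mapping, as applied to the quantizer output before decoding.
    z_back = rearrange(tokens_in, 'b t (h w) (k l e) -> b t e (h k) (w l)',
                       h=tokens_grid_res, k=token_res, l=token_res)
    assert torch.equal(z, z_back)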