""" Credits to https://github.com/CompVis/taming-transformers """ from typing import Tuple import numpy as np from einops import rearrange import torch import torch.nn as nn from .models.lpips import LPIPS from .models.nets import Encoder, Decoder class Tokenizer(nn.Module): def __init__(self, config: dict, with_lpips: bool = True) -> None: super().__init__() self.vocab_size = config["vocab_size"] self.embed_dim = config["embed_dim"] self.encoder = Encoder(config["encoder"]) self.decoder = Decoder(config["decoder"]) self.pre_quant_conv = torch.nn.Conv2d(self.encoder.config["z_channels"], self.embed_dim, 1) self.embedding = nn.Embedding(self.vocab_size, self.embed_dim) self.post_quant_conv = torch.nn.Conv2d(self.embed_dim, self.decoder.config["z_channels"], 1) self.embedding.weight.data.uniform_(-1.0 / self.vocab_size, 1.0 / self.vocab_size) self.lpips = LPIPS().eval() if with_lpips else None def __repr__(self) -> str: return "tokenizer" def forward(self, x: torch.Tensor, should_preprocess: bool = False, should_postprocess: bool = False) -> Tuple[torch.Tensor]: outputs = self.encode(x, should_preprocess) decoder_input = outputs["z"] + (outputs["z_quantized"] - outputs["z"]).detach() reconstructions = self.decode(decoder_input, should_postprocess) return outputs["z"], outputs["z_quantized"], reconstructions def encode(self, x: torch.Tensor, should_preprocess: bool = False) -> dict: if should_preprocess: x = self.preprocess_input(x) shape = x.shape # (..., C, H, W) x = x.view(-1, *shape[-3:]) z = self.encoder(x) z = self.pre_quant_conv(z) b, e, h, w = z.shape z_flattened = rearrange(z, 'b e h w -> (b h w) e') dist_to_embeddings = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + torch.sum(self.embedding.weight**2, dim=1) - 2 * torch.matmul(z_flattened, self.embedding.weight.t()) tokens = dist_to_embeddings.argmin(dim=-1) z_q = rearrange(self.embedding(tokens), '(b h w) e -> b e h w', b=b, e=e, h=h, w=w).contiguous() # Reshape to original z = z.reshape(*shape[:-3], *z.shape[1:]) z_q = z_q.reshape(*shape[:-3], *z_q.shape[1:]) tokens = tokens.reshape(*shape[:-3], -1) return { "z": z, "z_quantized": z_q, "tokens": tokens } def decode(self, z_q: torch.Tensor, should_postprocess: bool = False) -> torch.Tensor: shape = z_q.shape # (..., E, h, w) z_q = z_q.view(-1, *shape[-3:]) z_q = self.post_quant_conv(z_q) rec = self.decoder(z_q) rec = rec.reshape(*shape[:-3], *rec.shape[1:]) if should_postprocess: rec = self.postprocess_output(rec) return rec def decode_obs_tokens(self, obs_tokens, num_observations_tokens): embedded_tokens = self.embedding(obs_tokens) # (B, K, E) z = rearrange(embedded_tokens, 'b (h w) e -> b e h w', h=int(np.sqrt(num_observations_tokens))) rec = self.decode(z, should_postprocess=True) # (B, C, H, W) return torch.clamp(rec, 0, 1) @torch.no_grad() def encode_decode(self, x: torch.Tensor, should_preprocess: bool = False, should_postprocess: bool = False) -> torch.Tensor: z_q = self.encode(x, should_preprocess).z_quantized return self.decode(z_q, should_postprocess) def preprocess_input(self, x: torch.Tensor) -> torch.Tensor: """x is supposed to be channels first and in [0, 1]""" return x.mul(2).sub(1) def postprocess_output(self, y: torch.Tensor) -> torch.Tensor: """y is supposed to be channels first and in [-1, 1]""" return y.add(1).div(2)