""" Credits to https://github.com/CompVis/taming-transformers """ from typing import Tuple from einops import rearrange import torch import torch.nn as nn from .models.lpips import LPIPS from .models.nets import Encoder, Decoder class Tokenizer(nn.Module): def __init__(self, config: dict, with_lpips: bool = True) -> None: super().__init__() self.vocab_size = config["vocab_size"] self.embed_dim = config["embed_dim"] self.encoder = Encoder(config["encoder"]) self.pre_quant_conv = torch.nn.Conv2d(encoder.config.z_channels, self.embed_dim, 1) self.embedding = nn.Embedding(self.vocab_size, self.embed_dim) self.post_quant_conv = torch.nn.Conv2d(self.embed_dim, decoder.config.z_channels, 1) self.decoder = Decoder(config["decoder"]) self.embedding.weight.data.uniform_(-1.0 / self.vocab_size, 1.0 / self.vocab_size) self.lpips = LPIPS().eval() if with_lpips else None def __repr__(self) -> str: return "tokenizer" def forward(self, x: torch.Tensor, should_preprocess: bool = False, should_postprocess: bool = False) -> Tuple[torch.Tensor]: outputs = self.encode(x, should_preprocess) decoder_input = outputs.z + (outputs.z_quantized - outputs.z).detach() reconstructions = self.decode(decoder_input, should_postprocess) return outputs.z, outputs.z_quantized, reconstructions def encode(self, x: torch.Tensor, should_preprocess: bool = False) -> dict: if should_preprocess: x = self.preprocess_input(x) shape = x.shape # (..., C, H, W) x = x.view(-1, *shape[-3:]) z = self.encoder(x) z = self.pre_quant_conv(z) b, e, h, w = z.shape z_flattened = rearrange(z, 'b e h w -> (b h w) e') dist_to_embeddings = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + torch.sum(self.embedding.weight**2, dim=1) - 2 * torch.matmul(z_flattened, self.embedding.weight.t()) tokens = dist_to_embeddings.argmin(dim=-1) z_q = rearrange(self.embedding(tokens), '(b h w) e -> b e h w', b=b, e=e, h=h, w=w).contiguous() # Reshape to original z = z.reshape(*shape[:-3], *z.shape[1:]) z_q = z_q.reshape(*shape[:-3], *z_q.shape[1:]) tokens = tokens.reshape(*shape[:-3], -1) return { "z": z, "z_quantized": z_q, "tokens": tokens } def decode(self, z_q: torch.Tensor, should_postprocess: bool = False) -> torch.Tensor: shape = z_q.shape # (..., E, h, w) z_q = z_q.view(-1, *shape[-3:]) z_q = self.post_quant_conv(z_q) rec = self.decoder(z_q) rec = rec.reshape(*shape[:-3], *rec.shape[1:]) if should_postprocess: rec = self.postprocess_output(rec) return rec @torch.no_grad() def encode_decode(self, x: torch.Tensor, should_preprocess: bool = False, should_postprocess: bool = False) -> torch.Tensor: z_q = self.encode(x, should_preprocess).z_quantized return self.decode(z_q, should_postprocess) def preprocess_input(self, x: torch.Tensor) -> torch.Tensor: """x is supposed to be channels first and in [0, 1]""" return x.mul(2).sub(1) def postprocess_output(self, y: torch.Tensor) -> torch.Tensor: """y is supposed to be channels first and in [-1, 1]""" return y.add(1).div(2)