| | """ |
| | Credits to https://github.com/CompVis/taming-transformers |
| | """ |
| |
|
| | from typing import Tuple |
| |
|
| | from einops import rearrange |
| | import torch |
| | import torch.nn as nn |
| |
|
| | from models.lpips import LPIPS |
| | from models.nets import Encoder, Decoder |
| |
|
class Tokenizer(nn.Module):
    """VQ-VAE-style image tokenizer.

    Encodes images to a latent grid, quantizes each latent vector to its
    nearest codebook embedding (discrete tokens), and decodes quantized
    latents back to image space. Adapted from
    https://github.com/CompVis/taming-transformers.
    """

    def __init__(self, vocab_size: int, embed_dim: int, encoder: Encoder, decoder: Decoder, with_lpips: bool = True) -> None:
        """
        Args:
            vocab_size: number of entries in the discrete codebook.
            embed_dim: dimensionality of each codebook embedding.
            encoder: conv encoder producing latents with `encoder.config.z_channels` channels.
            decoder: conv decoder consuming latents with `decoder.config.z_channels` channels.
            with_lpips: if True, instantiate the LPIPS perceptual network (frozen in eval mode).
        """
        super().__init__()
        self.vocab_size = vocab_size
        self.encoder = encoder
        # 1x1 convs adapt between encoder/decoder channel counts and the codebook dim.
        self.pre_quant_conv = torch.nn.Conv2d(encoder.config.z_channels, embed_dim, 1)
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, decoder.config.z_channels, 1)
        self.decoder = decoder
        # Standard VQ-VAE codebook init: small uniform values scaled by vocab size.
        self.embedding.weight.data.uniform_(-1.0 / vocab_size, 1.0 / vocab_size)
        self.lpips = LPIPS().eval() if with_lpips else None

    def __repr__(self) -> str:
        return "tokenizer"

    def forward(self, x: torch.Tensor, should_preprocess: bool = False, should_postprocess: bool = False) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode, quantize, and decode `x`.

        Returns:
            Tuple of (z, z_quantized, reconstructions).
        """
        outputs = self.encode(x, should_preprocess)
        # Fix: `encode` returns a dict, so use key access.
        # Attribute access (`outputs.z`) raised AttributeError.
        z, z_quantized = outputs["z"], outputs["z_quantized"]
        # Straight-through estimator: the forward pass sees z_quantized,
        # while gradients flow back through z (the quantization step is
        # non-differentiable).
        decoder_input = z + (z_quantized - z).detach()
        reconstructions = self.decode(decoder_input, should_postprocess)
        return z, z_quantized, reconstructions

    def encode(self, x: torch.Tensor, should_preprocess: bool = False) -> dict:
        """Encode `x` to continuous latents, quantized latents, and token ids.

        Accepts arbitrary leading batch dimensions: input of shape
        (..., c, h, w) is flattened to a 4D batch for the conv encoder and
        the leading dims are restored on every output.

        Returns:
            dict with keys "z" (continuous latents), "z_quantized"
            (nearest-codebook latents), and "tokens" (codebook indices).
        """
        if should_preprocess:
            x = self.preprocess_input(x)
        shape = x.shape  # (..., c, h, w)
        x = x.view(-1, *shape[-3:])
        z = self.encoder(x)
        z = self.pre_quant_conv(z)
        b, e, h, w = z.shape
        z_flattened = rearrange(z, 'b e h w -> (b h w) e')
        # Squared L2 distance to every codebook vector via the expansion
        # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b (avoids materializing
        # the pairwise difference tensor).
        dist_to_embeddings = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + torch.sum(self.embedding.weight ** 2, dim=1) - 2 * torch.matmul(z_flattened, self.embedding.weight.t())

        tokens = dist_to_embeddings.argmin(dim=-1)
        z_q = rearrange(self.embedding(tokens), '(b h w) e -> b e h w', b=b, e=e, h=h, w=w).contiguous()

        # Restore the original leading batch dimensions on all outputs.
        z = z.reshape(*shape[:-3], *z.shape[1:])
        z_q = z_q.reshape(*shape[:-3], *z_q.shape[1:])
        tokens = tokens.reshape(*shape[:-3], -1)

        return {
            "z": z,
            "z_quantized": z_q,
            "tokens": tokens,
        }

    def decode(self, z_q: torch.Tensor, should_postprocess: bool = False) -> torch.Tensor:
        """Decode quantized latents back to image space.

        Like `encode`, accepts arbitrary leading batch dimensions on a
        (..., e, h, w) latent tensor and restores them on the output.
        """
        shape = z_q.shape
        z_q = z_q.view(-1, *shape[-3:])
        z_q = self.post_quant_conv(z_q)
        rec = self.decoder(z_q)
        rec = rec.reshape(*shape[:-3], *rec.shape[1:])
        if should_postprocess:
            rec = self.postprocess_output(rec)
        return rec

    @torch.no_grad()
    def encode_decode(self, x: torch.Tensor, should_preprocess: bool = False, should_postprocess: bool = False) -> torch.Tensor:
        """Reconstruct `x` through the quantized bottleneck (no gradients)."""
        # Fix: key access on the dict returned by `encode` (was `.z_quantized`).
        z_q = self.encode(x, should_preprocess)["z_quantized"]
        return self.decode(z_q, should_postprocess)

    def preprocess_input(self, x: torch.Tensor) -> torch.Tensor:
        """x is supposed to be channels first and in [0, 1]; maps to [-1, 1]."""
        return x.mul(2).sub(1)

    def postprocess_output(self, y: torch.Tensor) -> torch.Tensor:
        """y is supposed to be channels first and in [-1, 1]; maps to [0, 1]."""
        return y.add(1).div(2)
| |
|