|
|
""" |
|
|
Credits to https://github.com/CompVis/taming-transformers |
|
|
""" |
|
|
|
|
|
from typing import Tuple |
|
|
|
|
|
import numpy as np |
|
|
from einops import rearrange |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
|
|
|
from .models.lpips import LPIPS |
|
|
from .models.nets import Encoder, Decoder |
|
|
|
|
|
class Tokenizer(nn.Module):
    """Vector-quantizing autoencoder.

    Inputs are encoded to latents, every latent vector is replaced by its
    nearest codebook embedding (the index of that embedding is the "token"),
    and the quantized latents are decoded back to the input space.
    """

    def __init__(self, config: dict, with_lpips: bool = True) -> None:
        """Build the encoder, decoder, codebook and the 1x1 projection convs.

        Args:
            config: Mapping read for the keys ``"vocab_size"``, ``"embed_dim"``,
                ``"encoder"`` and ``"decoder"``; the last two are handed to
                ``Encoder`` and ``Decoder`` unchanged.
            with_lpips: When true, an ``LPIPS`` module is created in eval mode
                and stored on ``self.lpips``; otherwise ``self.lpips`` is None.
        """
        super().__init__()
        self.vocab_size = config["vocab_size"]
        self.embed_dim = config["embed_dim"]
        self.encoder = Encoder(config["encoder"])
        self.decoder = Decoder(config["decoder"])
        # 1x1 convolutions mapping encoder channels -> codebook dim -> decoder channels.
        self.pre_quant_conv = torch.nn.Conv2d(self.encoder.config["z_channels"], self.embed_dim, 1)
        self.embedding = nn.Embedding(self.vocab_size, self.embed_dim)
        self.post_quant_conv = torch.nn.Conv2d(self.embed_dim, self.decoder.config["z_channels"], 1)
        # Codebook entries start uniformly in [-1/vocab_size, 1/vocab_size].
        self.embedding.weight.data.uniform_(-1.0 / self.vocab_size, 1.0 / self.vocab_size)
        # Not used by the methods shown here.
        # NOTE(review): presumably consumed by a perceptual loss elsewhere - confirm.
        self.lpips = LPIPS().eval() if with_lpips else None

    def __repr__(self) -> str:
        return "tokenizer"

    def forward(
        self,
        x: torch.Tensor,
        should_preprocess: bool = False,
        should_postprocess: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode, quantize and reconstruct ``x``.

        Returns:
            ``(z, z_quantized, reconstructions)`` where ``z`` are the continuous
            latents and ``z_quantized`` their nearest codebook embeddings.
        """
        outputs = self.encode(x, should_preprocess)
        # Straight-through estimator: the forward value equals z_quantized,
        # while the gradient flows to z as if quantization were the identity.
        decoder_input = outputs["z"] + (outputs["z_quantized"] - outputs["z"]).detach()
        reconstructions = self.decode(decoder_input, should_postprocess)
        return outputs["z"], outputs["z_quantized"], reconstructions

    def encode(self, x: torch.Tensor, should_preprocess: bool = False) -> dict:
        """Encode ``x`` and quantize the latents against the codebook.

        The last three dimensions of ``x`` are fed to the encoder; any leading
        dimensions are flattened into the batch and restored on the outputs.

        Args:
            x: Input tensor with at least three dimensions.
            should_preprocess: Apply ``preprocess_input`` to ``x`` first.

        Returns:
            A dict with keys ``"z"`` (continuous latents), ``"z_quantized"``
            (codebook embeddings nearest to ``z``) and ``"tokens"`` (codebook
            indices, flattened to one axis per sample).
        """
        if should_preprocess:
            x = self.preprocess_input(x)
        shape = x.shape  # (..., C, H, W)
        # reshape (not view) so non-contiguous inputs are accepted as well.
        x = x.reshape(-1, *shape[-3:])
        z = self.encoder(x)
        z = self.pre_quant_conv(z)
        b, e, h, w = z.shape
        z_flattened = rearrange(z, 'b e h w -> (b h w) e')
        # Squared Euclidean distance to every codebook entry, expanded as
        # ||z||^2 + ||c||^2 - 2 * z.c so it is computed with a single matmul.
        dist_to_embeddings = (
            torch.sum(z_flattened ** 2, dim=1, keepdim=True)
            + torch.sum(self.embedding.weight ** 2, dim=1)
            - 2 * torch.matmul(z_flattened, self.embedding.weight.t())
        )

        tokens = dist_to_embeddings.argmin(dim=-1)
        z_q = rearrange(self.embedding(tokens), '(b h w) e -> b e h w', b=b, e=e, h=h, w=w).contiguous()

        # Restore the leading dimensions of the original input.
        z = z.reshape(*shape[:-3], *z.shape[1:])
        z_q = z_q.reshape(*shape[:-3], *z_q.shape[1:])
        tokens = tokens.reshape(*shape[:-3], -1)

        return {
            "z": z,
            "z_quantized": z_q,
            "tokens": tokens,
        }

    def decode(self, z_q: torch.Tensor, should_postprocess: bool = False) -> torch.Tensor:
        """Decode (quantized) latents back to the input space.

        The last three dimensions of ``z_q`` are fed through ``post_quant_conv``
        and the decoder; leading dimensions are flattened into the batch and
        restored on the result.

        Args:
            z_q: Latent tensor with at least three dimensions.
            should_postprocess: Apply ``postprocess_output`` to the result.
        """
        shape = z_q.shape  # (..., E, h, w)
        # reshape (not view) so non-contiguous latents are accepted as well.
        z_q = z_q.reshape(-1, *shape[-3:])
        z_q = self.post_quant_conv(z_q)
        rec = self.decoder(z_q)
        rec = rec.reshape(*shape[:-3], *rec.shape[1:])
        if should_postprocess:
            rec = self.postprocess_output(rec)
        return rec

    def decode_obs_tokens(self, obs_tokens: torch.Tensor, num_observations_tokens: int) -> torch.Tensor:
        """Decode codebook indices into reconstructions clamped to [0, 1].

        Args:
            obs_tokens: Codebook indices; the rearrange pattern below requires
                the embedded result to be 3-D, i.e. indices of shape (b, h*w).
            num_observations_tokens: Used to derive the latent grid height as
                ``int(sqrt(num_observations_tokens))``; the width is inferred
                by ``rearrange`` from the number of tokens.
        """
        embedded_tokens = self.embedding(obs_tokens)
        z = rearrange(embedded_tokens, 'b (h w) e -> b e h w', h=int(np.sqrt(num_observations_tokens)))
        rec = self.decode(z, should_postprocess=True)
        return torch.clamp(rec, 0, 1)

    @torch.no_grad()
    def encode_decode(
        self,
        x: torch.Tensor,
        should_preprocess: bool = False,
        should_postprocess: bool = False,
    ) -> torch.Tensor:
        """Return the reconstruction of ``x`` from its quantized latents, without gradients."""
        # encode() returns a plain dict, so the entry must be read by key.
        z_q = self.encode(x, should_preprocess)["z_quantized"]
        return self.decode(z_q, should_postprocess)

    def preprocess_input(self, x: torch.Tensor) -> torch.Tensor:
        """x is supposed to be channels first and in [0, 1]; returns 2 * x - 1."""
        return x.mul(2).sub(1)

    def postprocess_output(self, y: torch.Tensor) -> torch.Tensor:
        """y is supposed to be channels first and in [-1, 1]; returns (y + 1) / 2."""
        return y.add(1).div(2)
|
|
|