# world_model_test / iris / src / tokenizer.py  (commit bd89e7d)
"""
Credits to https://github.com/CompVis/taming-transformers
"""
from typing import Tuple
import numpy as np
from einops import rearrange
import torch
import torch.nn as nn
from .models.lpips import LPIPS
from .models.nets import Encoder, Decoder
class Tokenizer(nn.Module):
def __init__(self, config: dict, with_lpips: bool = True) -> None:
super().__init__()
self.vocab_size = config["vocab_size"]
self.embed_dim = config["embed_dim"]
self.encoder = Encoder(config["encoder"])
self.decoder = Decoder(config["decoder"])
self.pre_quant_conv = torch.nn.Conv2d(self.encoder.config["z_channels"], self.embed_dim, 1)
self.embedding = nn.Embedding(self.vocab_size, self.embed_dim)
self.post_quant_conv = torch.nn.Conv2d(self.embed_dim, self.decoder.config["z_channels"], 1)
self.embedding.weight.data.uniform_(-1.0 / self.vocab_size, 1.0 / self.vocab_size)
self.lpips = LPIPS().eval() if with_lpips else None
def __repr__(self) -> str:
return "tokenizer"
def forward(self, x: torch.Tensor, should_preprocess: bool = False, should_postprocess: bool = False) -> Tuple[torch.Tensor]:
outputs = self.encode(x, should_preprocess)
decoder_input = outputs["z"] + (outputs["z_quantized"] - outputs["z"]).detach()
reconstructions = self.decode(decoder_input, should_postprocess)
return outputs["z"], outputs["z_quantized"], reconstructions
def encode(self, x: torch.Tensor, should_preprocess: bool = False) -> dict:
if should_preprocess:
x = self.preprocess_input(x)
shape = x.shape # (..., C, H, W)
x = x.view(-1, *shape[-3:])
z = self.encoder(x)
z = self.pre_quant_conv(z)
b, e, h, w = z.shape
z_flattened = rearrange(z, 'b e h w -> (b h w) e')
dist_to_embeddings = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + torch.sum(self.embedding.weight**2, dim=1) - 2 * torch.matmul(z_flattened, self.embedding.weight.t())
tokens = dist_to_embeddings.argmin(dim=-1)
z_q = rearrange(self.embedding(tokens), '(b h w) e -> b e h w', b=b, e=e, h=h, w=w).contiguous()
# Reshape to original
z = z.reshape(*shape[:-3], *z.shape[1:])
z_q = z_q.reshape(*shape[:-3], *z_q.shape[1:])
tokens = tokens.reshape(*shape[:-3], -1)
return {
"z": z,
"z_quantized": z_q,
"tokens": tokens
}
def decode(self, z_q: torch.Tensor, should_postprocess: bool = False) -> torch.Tensor:
shape = z_q.shape # (..., E, h, w)
z_q = z_q.view(-1, *shape[-3:])
z_q = self.post_quant_conv(z_q)
rec = self.decoder(z_q)
rec = rec.reshape(*shape[:-3], *rec.shape[1:])
if should_postprocess:
rec = self.postprocess_output(rec)
return rec
def decode_obs_tokens(self, obs_tokens, num_observations_tokens):
embedded_tokens = self.embedding(obs_tokens) # (B, K, E)
z = rearrange(embedded_tokens, 'b (h w) e -> b e h w', h=int(np.sqrt(num_observations_tokens)))
rec = self.decode(z, should_postprocess=True) # (B, C, H, W)
return torch.clamp(rec, 0, 1)
@torch.no_grad()
def encode_decode(self, x: torch.Tensor, should_preprocess: bool = False, should_postprocess: bool = False) -> torch.Tensor:
z_q = self.encode(x, should_preprocess).z_quantized
return self.decode(z_q, should_postprocess)
def preprocess_input(self, x: torch.Tensor) -> torch.Tensor:
"""x is supposed to be channels first and in [0, 1]"""
return x.mul(2).sub(1)
def postprocess_output(self, y: torch.Tensor) -> torch.Tensor:
"""y is supposed to be channels first and in [-1, 1]"""
return y.add(1).div(2)