|
|
from dataclasses import dataclass |
|
|
import math |
|
|
from typing import Dict, Tuple |
|
|
|
|
|
from einops import rearrange |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
|
|
|
from .models.convnet import FrameEncoder, FrameDecoder |
|
|
from .data import Batch |
|
|
from .models.tokenizer.quantizer import Quantizer, QuantizerOutput |
|
|
from .models.utils import init_weights, LossWithIntermediateLosses |
|
|
|
|
|
class Tokenizer(nn.Module):
    """Vector-quantizes frame transitions into a grid of discrete tokens.

    The encoder consumes a frame pair (x1, x2) together with an action embedded
    as one extra image-sized channel, producing a latent feature map.  That map
    is split into a ``tokens_grid_res x tokens_grid_res`` grid of patches, each
    flattened and quantized by ``self.quantizer``.  The decoder reconstructs x2
    from x1 (embedded by ``frame_cnn``), the action, and the quantized patches.

    Tensors follow the layout ``(b, t, c, h, w)`` — batch, time, channels,
    height, width — as established by the rearrange patterns below.
    """

    def __init__(self, config: dict) -> None:
        super().__init__()
        self.config = config

        # Spatial side of the encoder's latent map: each enabled "down" stage halves the image.
        self.latent_res = config["image_size"] // 2 ** sum(config["encoder_config"]["down"])
        # Tokens lie on a square grid, so num_tokens is assumed to be a perfect square.
        self.tokens_grid_res = int(math.sqrt(config["num_tokens"]))
        # Side length (in latent pixels) of the patch covered by a single token.
        self.token_res = self.latent_res // self.tokens_grid_res

        # Action embeddings: one full image-resolution plane for the encoder input,
        # and `decoder_act_channels` latent-resolution planes for the decoder input.
        self.encoder_act_emb = nn.Embedding(config["num_actions"], config["image_size"] ** 2)
        self.decoder_act_emb = nn.Embedding(config["num_actions"], config["decoder_act_channels"] * self.latent_res ** 2)

        # Each quantizer input is one flattened latent patch (latent_dim channels x token_res^2 pixels).
        self.quantizer = Quantizer(
            config["codebook_size"], config["codebook_dim"],
            input_dim=config["encoder_config"]["latent_dim"] * self.token_res ** 2,
            max_codebook_updates_with_revival=config["max_codebook_updates_with_revival"],
        )

        self.encoder = FrameEncoder(config["encoder_config"])
        self.decoder = FrameDecoder(config["decoder_config"])
        # Embeds the conditioning frame x1 at decode time.
        self.frame_cnn = FrameEncoder(config["frame_cnn_config"])

        self.apply(init_weights)

    def __repr__(self) -> str:
        # Short name used for identification/logging (not a debug repr).
        return "tokenizer"

    def forward(self, x1: torch.FloatTensor, a: torch.LongTensor, x2: torch.FloatTensor) -> QuantizerOutput:
        """Encode the transition (x1, a, x2) and quantize it into tokens."""
        z = self.encode(x1, a, x2)
        # Split the latent map into an (h w) grid of flattened (k l c) patches, one per token.
        z = rearrange(z, 'b t c (h k) (w l) -> b t (h w) (k l c)', h=self.tokens_grid_res, w=self.tokens_grid_res)
        return self.quantizer(z)

    def compute_loss(self, batch: Batch, **kwargs) -> Tuple[LossWithIntermediateLosses, Dict]:
        """Quantization + reconstruction losses over consecutive frame pairs of a batch.

        Returns the combined loss object and the quantizer's metrics dict.
        """
        x1 = batch.observations[:, :-1]
        a = batch.actions[:, :-1]
        x2 = batch.observations[:, 1:]

        quantizer_outputs = self(x1, a, x2)

        # Undo the token patchification to recover a latent-shaped map for the decoder.
        r = self.decode(x1, a, rearrange(quantizer_outputs.q, 'b t (h w) (k l e) -> b t e (h k) (w l)', h=self.tokens_grid_res, k=self.token_res, l=self.token_res))
        delta = (x2 - r)
        # Keep only transitions where BOTH frames of the pair are valid (non-padding).
        delta = delta[torch.logical_and(batch.mask_padding[:, 1:], batch.mask_padding[:, :-1])]

        losses = {
            **quantizer_outputs.loss,
            'reconstruction_loss_l1': 0.1 * torch.abs(delta).mean(),
            'reconstruction_loss_l2': delta.pow(2).mean(),
            # Penalize the single worst pixel per sample to discourage localized artifacts.
            'reconstruction_loss_l2_worst_pixel': 0.01 * rearrange(delta, 'b c h w -> b (c h w)').pow(2).max(dim=-1)[0].mean(),
        }

        return LossWithIntermediateLosses(**losses), quantizer_outputs.metrics

    def encode(self, x1: torch.FloatTensor, a: torch.LongTensor, x2: torch.FloatTensor) -> torch.FloatTensor:
        """Encode a transition into a latent feature map of shape (b, t, c, latent_res, latent_res)."""
        # Action becomes a single image-resolution plane, concatenated between the two frames.
        a_emb = rearrange(self.encoder_act_emb(a), 'b t (h w) -> b t 1 h w', h=x1.size(3))
        encoder_input = torch.cat((x1, a_emb, x2), dim=2)
        return self.encoder(encoder_input)

    def decode(self, x1: torch.FloatTensor, a: torch.LongTensor, q2: torch.FloatTensor, should_clamp: bool = False) -> torch.FloatTensor:
        """Reconstruct x2 from frame x1, action a, and quantized latent map q2.

        With ``should_clamp``, the output is clamped to [0, 1] and snapped to the
        8-bit pixel grid, matching stored observations.
        """
        x1_emb = self.frame_cnn(x1)
        # Action becomes `decoder_act_channels` planes at latent resolution.
        a_emb = rearrange(self.decoder_act_emb(a), 'b t (c h w) -> b t c h w', c=self.config["decoder_act_channels"], h=x1_emb.size(3))

        decoder_input = torch.cat((x1_emb, a_emb, q2), dim=2)

        r = self.decoder(decoder_input)
        if should_clamp:
            r = torch.clamp(r, 0, 1).mul(255).round().div(255)
        return r

    @torch.no_grad()
    def encode_decode(self, x1: torch.FloatTensor, a: torch.LongTensor, x2: torch.FloatTensor) -> torch.Tensor:
        """Full round trip: tokenize the transition, then reconstruct x2 (clamped)."""
        # forward() performs the encode -> patchify -> quantize pipeline; previously
        # this was duplicated inline here with an equivalent rearrange spelling.
        q = rearrange(self(x1, a, x2).q, 'b t (h w) (k l e) -> b t e (h k) (w l)', h=self.tokens_grid_res, k=self.token_res, l=self.token_res)
        return self.decode(x1, a, q, should_clamp=True)

    @torch.no_grad()
    def burn_in(self, obs: torch.FloatTensor, act: torch.LongTensor) -> torch.LongTensor:
        """Tokenize a trajectory: obs has one more timestep than act (frame pairs)."""
        assert obs.size(1) == act.size(1) + 1
        quantizer_output = self(obs[:, :-1], act, obs[:, 1:])

        return quantizer_output.tokens
|
|
|