# NOTE(review): removed non-Python viewer artifacts that preceded the code
# (file-size banner, repeated commit-hash columns, and a line-number gutter).
from dataclasses import dataclass
import math
from typing import Dict, Tuple
from einops import rearrange
import torch
import torch.nn as nn
from .models.convnet import FrameEncoder, FrameDecoder
from .data import Batch
from .models.tokenizer.quantizer import Quantizer, QuantizerOutput
from .models.utils import init_weights, LossWithIntermediateLosses
class Tokenizer(nn.Module):
    """Action-conditioned frame tokenizer.

    Encodes a transition (x1, a, x2) into a grid of discrete tokens via a
    vector quantizer, and reconstructs x2 from (x1, a, tokens).

    Tensor conventions (inferred from the einops patterns below — confirm
    against callers): frames are (b, t, c, h, w); quantized latents are
    (b, t, num_tokens, token_dim).
    """
    def __init__(self, config: dict) -> None:
        super().__init__()
        self.config = config
        # Spatial side of the encoder's latent map: each truthy entry in
        # encoder_config["down"] halves the input resolution.
        self.latent_res = config["image_size"] // 2 ** sum(config["encoder_config"]["down"])
        # Tokens are laid out on a square grid, so num_tokens is assumed to be
        # a perfect square — TODO confirm; int(sqrt(.)) silently truncates.
        self.tokens_grid_res = int(math.sqrt(config["num_tokens"]))
        # Side length (in latent pixels) of the patch covered by one token.
        # Assumes tokens_grid_res divides latent_res exactly.
        self.token_res = self.latent_res // self.tokens_grid_res
        # Per-action embedding reshaped in encode() into a single-channel,
        # full-resolution plane (assumes square frames; see encode()).
        self.encoder_act_emb = nn.Embedding(config["num_actions"], config["image_size"] ** 2)
        # Per-action embedding reshaped in decode() into
        # decoder_act_channels planes at latent resolution.
        self.decoder_act_emb = nn.Embedding(config["num_actions"], config["decoder_act_channels"] * self.latent_res ** 2)
        # One token vector = a token_res x token_res latent patch with all of
        # its channels flattened (matches the rearrange in forward()).
        self.quantizer = Quantizer(
        config["codebook_size"], config["codebook_dim"],
        input_dim=config["encoder_config"]["latent_dim"] * self.token_res ** 2,
        max_codebook_updates_with_revival=config["max_codebook_updates_with_revival"]
        )
        self.encoder = FrameEncoder(config["encoder_config"])
        self.decoder = FrameDecoder(config["decoder_config"])
        # Separate CNN used only in decode() to embed the conditioning frame x1.
        self.frame_cnn = FrameEncoder(config["frame_cnn_config"])
        self.apply(init_weights)
    def __repr__(self) -> str:
        # Fixed name — presumably used for logging/checkpoint naming; verify
        # against the training loop before changing.
        return "tokenizer"
    def forward(self, x1: torch.FloatTensor, a: torch.LongTensor, x2: torch.FloatTensor) -> QuantizerOutput:
        """Encode the transition (x1, a, x2) and quantize it into tokens."""
        z = self.encode(x1, a, x2)
        # Cut the latent map into an (h x w) grid of patches; each patch's
        # pixels and channels become one flat vector per token.
        z = rearrange(z, 'b t c (h k) (w l) -> b t (h w) (k l c)', h=self.tokens_grid_res, w=self.tokens_grid_res)
        return self.quantizer(z)
    def compute_loss(self, batch: Batch, **kwargs) -> Tuple[LossWithIntermediateLosses, Dict]:
        """Quantization + reconstruction losses over consecutive frame pairs.

        Returns the combined loss object and the quantizer's metrics dict.
        """
        x1 = batch.observations[:, :-1]
        a = batch.actions[:, :-1]
        x2 = batch.observations[:, 1:]
        quantizer_outputs = self(x1, a, x2)
        # Inverse of the rearrange in forward(): tokens back to a latent map,
        # then reconstruct x2 conditioned on x1 and the action.
        r = self.decode(x1, a, rearrange(quantizer_outputs.q, 'b t (h w) (k l e) -> b t e (h k) (w l)', h=self.tokens_grid_res, k=self.token_res, l=self.token_res))
        delta = (x2 - r)
        # Keep only pairs where both frames are valid (non-padding) steps.
        delta = delta[torch.logical_and(batch.mask_padding[:, 1:], batch.mask_padding[:, :-1])]
        losses = {
        **quantizer_outputs.loss,
        # 0.1 / 0.01 weights balance the terms — presumably tuned; confirm
        # against the training config before changing.
        'reconstruction_loss_l1': 0.1 * torch.abs(delta).mean(),
        'reconstruction_loss_l2': delta.pow(2).mean(),
        # Penalize the single worst pixel error per (valid) frame.
        'reconstruction_loss_l2_worst_pixel': 0.01 * rearrange(delta, 'b c h w -> b (c h w)').pow(2).max(dim=-1)[0].mean(),
        }
        return LossWithIntermediateLosses(**losses), quantizer_outputs.metrics
    def encode(self, x1: torch.FloatTensor, a: torch.LongTensor, x2: torch.FloatTensor) -> torch.FloatTensor:
        """Encode a transition: both frames plus a spatial action plane."""
        # Unflatten the image_size**2 action embedding into one single-channel
        # plane with h = x1's height (assumes square frames — TODO confirm).
        a_emb = rearrange(self.encoder_act_emb(a), 'b t (h w) -> b t 1 h w', h=x1.size(3))
        # Channel-wise concat: [x1 | action plane | x2].
        encoder_input = torch.cat((x1, a_emb, x2), dim=2)
        z = self.encoder(encoder_input)
        return z
    def decode(self, x1: torch.FloatTensor, a: torch.LongTensor, q2: torch.FloatTensor, should_clamp: bool = False) -> torch.FloatTensor:
        """Reconstruct x2 from the conditioning frame x1, action a, and
        quantized latent map q2.

        With should_clamp=True the output is clamped to [0, 1] and snapped to
        the 255-level pixel grid (for generating actual frames).
        """
        x1_emb = self.frame_cnn(x1)
        # Unflatten the action embedding to latent resolution, matching x1_emb's
        # spatial size.
        a_emb = rearrange(self.decoder_act_emb(a), 'b t (c h w) -> b t c h w', c=self.config["decoder_act_channels"], h=x1_emb.size(3))
        # Channel-wise concat: [frame features | action planes | quantized latent].
        decoder_input = torch.cat((x1_emb, a_emb, q2), dim=2)
        r = self.decoder(decoder_input)
        r = torch.clamp(r, 0, 1).mul(255).round().div(255) if should_clamp else r
        return r
    @torch.no_grad()
    def encode_decode(self, x1: torch.FloatTensor, a: torch.LongTensor, x2: torch.FloatTensor) -> torch.Tensor:
        """Full round trip (encode -> quantize -> decode), e.g. for visual
        inspection of reconstructions. Gradient-free."""
        z = self.encode(x1, a, x2)
        # Same patch flattening as forward(), here specified via k/l instead
        # of h/w (equivalent factorization of the spatial axes).
        z = rearrange(z, 'b t c (h k) (w l) -> b t (h w) (k l c)', k=self.token_res, l=self.token_res)
        q = rearrange(self.quantizer(z).q, 'b t (h w) (k l e) -> b t e (h k) (w l)', h=self.tokens_grid_res, k=self.token_res, l=self.token_res)
        r = self.decode(x1, a, q, should_clamp=True)
        return r
    @torch.no_grad()
    def burn_in(self, obs: torch.FloatTensor, act: torch.LongTensor) -> torch.LongTensor:
        """Tokenize a trajectory without gradients.

        obs must have exactly one more timestep than act (each action maps the
        frame at t to the frame at t+1). Returns the quantizer's token ids.
        """
        assert obs.size(1) == act.size(1) + 1
        quantizer_output = self(obs[:, :-1], act, obs[:, 1:])
        return quantizer_output.tokens