import math
from typing import Dict, Tuple

from einops import rearrange
import torch
import torch.nn as nn

from .models.convnet import FrameEncoder, FrameDecoder
from .data import Batch
from .models.tokenizer.quantizer import Quantizer, QuantizerOutput
from .models.utils import init_weights, LossWithIntermediateLosses

class Tokenizer(nn.Module):
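    """Tokenizes a frame transition (x1, a, x2) into a square grid of discrete
    codes, and reconstructs x2 from (x1, a) and the quantized codes."""
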
    def __init__(self, config: dict) -> None:
        super().__init__()
        self.config = config

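        # Token geometry: the encoder reduces image_size by a factor of 2 per
        # downsampling stage listed in `down`; num_tokens codes are laid out on
        # a square grid, each covering a (token_res x token_res) latent patch.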
        self.latent_res = config["image_size"] // 2 ** sum(config["encoder_config"]["down"])
        self.tokens_grid_res = int(math.sqrt(config["num_tokens"]))
        self.token_res = self.latent_res // self.tokens_grid_res

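        # Action embeddings: one full-resolution plane fed to the encoder, and
        # `decoder_act_channels` latent-resolution planes fed to the decoder.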
        self.encoder_act_emb = nn.Embedding(config["num_actions"], config["image_size"] ** 2)
        self.decoder_act_emb = nn.Embedding(config["num_actions"], config["decoder_act_channels"] * self.latent_res ** 2)

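        # Each token vector (one per grid cell) has dim latent_dim * token_res**2
        # and is quantized against a codebook of `codebook_size` entries.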
        self.quantizer = Quantizer(
            config["codebook_size"], config["codebook_dim"],
            input_dim=config["encoder_config"]["latent_dim"] * self.token_res ** 2,
            max_codebook_updates_with_revival=config["max_codebook_updates_with_revival"]
        )

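        # Transition encoder, frame decoder, and a CNN that embeds x1 for decoding.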
        self.encoder = FrameEncoder(config["encoder_config"])
        self.decoder = FrameDecoder(config["decoder_config"])
        self.frame_cnn = FrameEncoder(config["frame_cnn_config"])

        self.apply(init_weights)

    def __repr__(self) -> str:
        return "tokenizer"

    def forward(self, x1: torch.FloatTensor, a: torch.LongTensor, x2: torch.FloatTensor) -> QuantizerOutput:
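        # Encode the transition, fold the latent map into num_tokens vectors
        # (one per grid cell), then quantize them.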
        z = self.encode(x1, a, x2)
        z = rearrange(z, 'b t c (h k) (w l) -> b t (h w) (k l c)', h=self.tokens_grid_res, w=self.tokens_grid_res)

        return self.quantizer(z)

    def compute_loss(self, batch: Batch, **kwargs) -> Tuple[LossWithIntermediateLosses, Dict]:
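        # Split the sequence into transitions (x_t, a_t, x_{t+1}).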
        x1 = batch.observations[:, :-1]
        a = batch.actions[:, :-1]
        x2 = batch.observations[:, 1:]

        quantizer_outputs = self(x1, a, x2)

        q = rearrange(quantizer_outputs.q, 'b t (h w) (k l e) -> b t e (h k) (w l)', h=self.tokens_grid_res, k=self.token_res, l=self.token_res)
        r = self.decode(x1, a, q)
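        # Reconstruction error, keeping only timesteps where both frames of
        # the transition are real (not padding).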
        delta = (x2 - r)
        delta = delta[torch.logical_and(batch.mask_padding[:, 1:], batch.mask_padding[:, :-1])]

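        # Quantizer losses plus weighted L1 / L2 / worst-pixel reconstruction terms.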
        losses = {
            **quantizer_outputs.loss,
            'reconstruction_loss_l1': 0.1 * torch.abs(delta).mean(),
            'reconstruction_loss_l2': delta.pow(2).mean(),
            'reconstruction_loss_l2_worst_pixel': 0.01 * rearrange(delta, 'b c h w -> b (c h w)').pow(2).max(dim=-1)[0].mean(),
        }

        return LossWithIntermediateLosses(**losses), quantizer_outputs.metrics

    def encode(self, x1: torch.FloatTensor, a: torch.LongTensor, x2: torch.FloatTensor) -> torch.FloatTensor:
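        # Broadcast the action embedding to a single (H, W) plane and stack it
        # between the two frames along the channel axis.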
        a_emb = rearrange(self.encoder_act_emb(a), 'b t (h w) -> b t 1 h w', h=x1.size(3))
        encoder_input = torch.cat((x1, a_emb, x2), dim=2)
        z = self.encoder(encoder_input)

        return z

    def decode(self, x1: torch.FloatTensor, a: torch.LongTensor, q2: torch.FloatTensor, should_clamp: bool = False) -> torch.FloatTensor:
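        # Decode x2 from features of x1, the action planes and the quantized
        # latents; optionally snap outputs to the 256 8-bit pixel levels.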
        x1_emb = self.frame_cnn(x1)
        a_emb = rearrange(self.decoder_act_emb(a), 'b t (c h w) -> b t c h w', c=self.config["decoder_act_channels"], h=x1_emb.size(3))

        decoder_input = torch.cat((x1_emb, a_emb, q2), dim=2)

        r = self.decoder(decoder_input)
        r = torch.clamp(r, 0, 1).mul(255).round().div(255) if should_clamp else r

        return r

    @torch.no_grad()
    def encode_decode(self, x1: torch.FloatTensor, a: torch.LongTensor, x2: torch.FloatTensor) -> torch.Tensor:
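        # Full encode -> quantize -> decode round trip; returns the clamped
        # reconstruction of x2.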
        z = self.encode(x1, a, x2)
        z = rearrange(z, 'b t c (h k) (w l) -> b t (h w) (k l c)', k=self.token_res, l=self.token_res)
        q = rearrange(self.quantizer(z).q, 'b t (h w) (k l e) -> b t e (h k) (w l)', h=self.tokens_grid_res, k=self.token_res, l=self.token_res)
        r = self.decode(x1, a, q, should_clamp=True)

        return r

    @torch.no_grad()
    def burn_in(self, obs: torch.FloatTensor, act: torch.LongTensor) -> torch.LongTensor:
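        # Tokenize a whole trajectory: obs holds one more frame than act, so
        # action a_t is paired with the transition (obs_t, obs_{t+1}).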
        assert obs.size(1) == act.size(1) + 1
        quantizer_output = self(obs[:, :-1], act, obs[:, 1:])

        return quantizer_output.tokens
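
# ---------------------------------------------------------------------------
# Shape sketch (illustrative, not part of the original module). The einops
# patterns in `forward` and `decode` fold a latent map into token vectors and
# back; the two rearranges below are exact inverses. The sizes are
# hypothetical placeholders chosen so that latent_res = 8, tokens_grid_res = 4
# and token_res = 2:
#
#     z = torch.randn(1, 1, 32, 8, 8)   # (b, t, latent_dim, latent_res, latent_res)
#     tokens = rearrange(z, 'b t c (h k) (w l) -> b t (h w) (k l c)', h=4, w=4)
#     tokens.shape                      # (1, 1, 16, 128): 16 tokens of dim 2*2*32
#     z2 = rearrange(tokens, 'b t (h w) (k l e) -> b t e (h k) (w l)', h=4, k=2, l=2)
#     torch.equal(z, z2)                # True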