File size: 3,559 Bytes
23bc32f
 
 
 
 
fb56df2
f9f6093
23bc32f
 
fb56df2
23bc32f
 
 
fb56df2
 
23bc32f
 
fb56df2
 
23bc32f
 
fb56df2
 
 
23bc32f
 
fb56df2
 
 
23bc32f
 
 
 
f77d622
23bc32f
 
 
 
 
 
 
 
 
 
 
 
 
 
fb56df2
23bc32f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a9b48d6
 
 
 
 
 
 
 
 
 
 
20c9322
a9b48d6
 
 
 
23bc32f
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import math
from einops import rearrange
import torch
import torch.nn as nn

from .models.convnet import FrameEncoder, FrameDecoder
from .models.quantizer import Quantizer

class Tokenizer(nn.Module):
    """Action-conditioned frame tokenizer.

    Encodes a transition (x1, a, x2) into a grid of discrete tokens with a
    vector quantizer, and decodes tokens back into a reconstruction of x2
    conditioned on the previous frame x1 and the action a.

    Tensors are laid out as (batch, time, channels, height, width); tokens
    live on a square grid of side ``tokens_grid_res``, each token covering a
    ``token_res`` x ``token_res`` patch of the encoder latent map.
    """

    def __init__(self, config: dict) -> None:
        super().__init__()
        self.config = config

        # Latent spatial resolution: image_size halved once per enabled
        # downsampling stage of the encoder.
        self.latent_res = config["image_size"] // 2 ** sum(config["encoder_config"]["down"])
        # Tokens form a square grid, so num_tokens must be a perfect square.
        self.tokens_grid_res = int(math.sqrt(config["num_tokens"]))
        # Side length (in latent pixels) of the patch one token covers.
        self.token_res = self.latent_res // self.tokens_grid_res

        # Action embeddings: a full-resolution single-channel plane for the
        # encoder input, and decoder_act_channels planes at latent resolution
        # for the decoder input.
        self.encoder_act_emb = nn.Embedding(config["num_actions"], config["image_size"] ** 2)
        self.decoder_act_emb = nn.Embedding(config["num_actions"], config["decoder_act_channels"] * self.latent_res ** 2)

        self.quantizer = Quantizer(
            config["codebook_size"], config["codebook_dim"],
            # Each token flattens a token_res x token_res latent patch.
            input_dim=config["encoder_config"]["latent_dim"] * self.token_res ** 2,
            max_codebook_updates_with_revival=config["max_codebook_updates_with_revival"]
        )

        self.encoder = FrameEncoder(config["encoder_config"])
        self.decoder = FrameDecoder(config["decoder_config"])
        # Separate CNN that embeds the conditioning frame x1 for the decoder.
        self.frame_cnn = FrameEncoder(config["frame_cnn_config"])

    def __repr__(self) -> str:
        # Short name used to identify this module (e.g. in logs/checkpoints).
        return "tokenizer"

    def forward(self, x1: torch.FloatTensor, a: torch.LongTensor, x2: torch.FloatTensor) -> dict:
        """Encode a transition and quantize it; returns the quantizer output."""
        z = self.encode(x1, a, x2)
        # Flatten each token_res x token_res latent patch into one token vector.
        z = rearrange(z, 'b t c (h k) (w l) -> b t (h w) (k l c)', h=self.tokens_grid_res, w=self.tokens_grid_res)

        return self.quantizer(z)

    def encode(self, x1: torch.FloatTensor, a: torch.LongTensor, x2: torch.FloatTensor) -> torch.FloatTensor:
        """Encode (x1, a, x2) into a latent map of shape (b, t, c, h, w)."""
        # The action embedding is reshaped into one image-resolution plane and
        # concatenated channel-wise between the two frames.
        a_emb = rearrange(self.encoder_act_emb(a), 'b t (h w) -> b t 1 h w', h=x1.size(3))
        encoder_input = torch.cat((x1, a_emb, x2), dim=2)
        z = self.encoder(encoder_input)

        return z

    def decode(self, x1: torch.FloatTensor, a: torch.LongTensor, q2: torch.FloatTensor, should_clamp: bool = False) -> torch.FloatTensor:
        """Reconstruct x2 from quantized latents q2, conditioned on (x1, a).

        When ``should_clamp`` is set, the output is clamped to [0, 1] and
        snapped to 8-bit levels (multiples of 1/255).
        """
        x1_emb = self.frame_cnn(x1)
        # Action embedding reshaped to decoder_act_channels latent-res planes.
        a_emb = rearrange(self.decoder_act_emb(a), 'b t (c h w) -> b t c h w', c=self.config["decoder_act_channels"], h=x1_emb.size(3))

        decoder_input = torch.cat((x1_emb, a_emb, q2), dim=2)

        r = self.decoder(decoder_input)
        r = torch.clamp(r, 0, 1).mul(255).round().div(255) if should_clamp else r

        return r

    @torch.no_grad()
    def encode_decode(self, x1: torch.FloatTensor, a: torch.LongTensor, x2: torch.FloatTensor) -> torch.Tensor:
        """Full round trip: encode, quantize, decode; returns the reconstruction."""
        z = self.encode(x1, a, x2)
        z = rearrange(z, 'b t c (h k) (w l) -> b t (h w) (k l c)', k=self.token_res, l=self.token_res)
        q = rearrange(self.quantizer(z).q, 'b t (h w) (k l e) -> b t e (h k) (w l)', h=self.tokens_grid_res, k=self.token_res, l=self.token_res)
        r = self.decode(x1, a, q, should_clamp=True)

        return r

    def embed_tokens(self, tokens: torch.LongTensor) -> torch.FloatTensor:
        """Look up token embeddings and fold them back into a latent map.

        Inverse of the flattening done in ``forward``/``encode_decode``:
        (b, t, h*w, k*l*e) -> (b, t, e, h*k, w*l). Uses the same rearrange
        pattern as ``encode_decode`` instead of a manual reshape/permute.
        """
        q = self.quantizer.embed_tokens(tokens)

        return rearrange(
            q, 'b t (h w) (k l e) -> b t e (h k) (w l)',
            h=self.tokens_grid_res, w=self.tokens_grid_res,
            k=self.token_res, l=self.token_res,
        )

    @torch.no_grad()
    def burn_in(self, obs: torch.FloatTensor, act: torch.LongTensor) -> torch.LongTensor:
        """Tokenize a trajectory: obs has one more step than act (x_t, a_t, x_{t+1})."""
        assert obs.size(1) == act.size(1) + 1
        quantizer_output = self(obs[:, :-1], act, obs[:, 1:])

        return quantizer_output.tokens