File size: 5,123 Bytes
2cba492
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
# Finite Scalar Quantization: https://arxiv.org/abs/2309.15505

import torch
from torch import nn

from ..util import get_logger

logger = get_logger()


def round_ste(z: torch.Tensor) -> torch.Tensor:
    """Round to the nearest integer with straight-through gradients."""
    # Forward value equals z.round(); the detached residual keeps
    # d(out)/dz == 1 so gradients flow through the rounding.
    return z + (z.round() - z).detach()


def get_entropy(prob: torch.Tensor, eps: float = 1e-10) -> torch.Tensor:
    """Shannon entropy (nats) along the last dimension; `eps` guards log(0)."""
    log_prob = torch.log(prob + eps)
    return -(prob * log_prob).sum(dim=-1)


class FSQ(nn.Module):
    """Finite Scalar Quantization (https://arxiv.org/abs/2309.15505).

    Each of the `len(levels)` channels is bounded with a tanh and rounded to
    one of `levels[i]` integer positions, then renormalized to roughly
    [-1, 1]. The per-channel digits can be flattened to a single index in an
    implicit codebook of size prod(levels).
    """

    def __init__(self, levels: list[int]):
        super().__init__()
        self.levels = levels
        self.dim = len(levels)

        _levels = torch.tensor(levels, dtype=torch.long)
        self.register_buffer("_levels", _levels, persistent=False)
        # Mixed-radix basis: _basis[i] = prod(levels[:i]); used to flatten
        # per-channel digits into one flat codebook index.
        _basis = torch.cumprod(torch.tensor([1] + levels[:-1]), dim=0, dtype=torch.long)
        self.register_buffer("_basis", _basis, persistent=False)

    def bound(self, z: torch.Tensor, eps: float = 1e-3) -> torch.Tensor:
        """Bound `z`, an array of shape (..., d), so rounding hits valid levels."""
        half_l = (self._levels - 1) * (1 - eps) / 2
        # Even level counts have no center point; shift by 0.5 so the
        # rounded grid stays symmetric around zero.
        offset = torch.where(self._levels % 2 == 0, 0.5, 0.0)
        shift = (offset / half_l).tan()
        return (z + shift).tanh() * half_l - offset

    def quantize(self, z: torch.Tensor) -> torch.Tensor:
        """Quantizes z, returns quantized zhat, same shape as z."""
        quantized = round_ste(self.bound(z))
        half_width = self._levels // 2  # Renormalize to [-1, 1].
        return quantized / half_width

    def _scale_and_shift(self, zhat_normalized: torch.Tensor) -> torch.Tensor:
        # Normalized [-1, 1] codes -> non-negative digits in [0, levels - 1].
        half_width = self._levels // 2
        return (zhat_normalized * half_width) + half_width

    def _scale_and_shift_inverse(self, zhat: torch.Tensor) -> torch.Tensor:
        # Digits in [0, levels - 1] -> normalized [-1, 1] codes.
        half_width = self._levels // 2
        return (zhat - half_width) / half_width

    def codes_to_indices(self, zhat: torch.Tensor) -> torch.Tensor:
        """Converts a `code` to an index in the codebook."""
        # (B, T, C) -> (B, T)
        assert zhat.shape[-1] == len(self.levels)
        zhat = self._scale_and_shift(zhat)
        # Round before the long cast: the normalize/denormalize round trip can
        # leave values like 20.999..., and a bare cast truncates toward zero,
        # which would corrupt the index.
        return (zhat * self._basis.to(torch.float64)).round().to(torch.long).sum(dim=-1)

    def indices_to_codes(self, indices: torch.Tensor) -> torch.Tensor:
        """Inverse of `codes_to_indices`."""
        # (B, T) -> (B, T, C): extract each mixed-radix digit, then recenter.
        indices = indices.unsqueeze(-1)
        codes_non_centered = (indices // self._basis) % self._levels
        return self._scale_and_shift_inverse(codes_non_centered)

    def encode(self, z: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Quantize z and return (z_q, flat codebook indices)."""
        z_q = self.quantize(z)
        indices = self.codes_to_indices(z_q)  # (B, T)
        return z_q, indices

    def decode(self, indices: torch.Tensor) -> torch.Tensor:
        """Recover normalized codes (B, T, C) from flat indices (B, T)."""
        z_q = self.indices_to_codes(indices)
        return z_q

    def forward(self, z: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Quantize z; returns (z_q, indices), same as `encode`."""
        z_q = self.quantize(z)
        indices = self.codes_to_indices(z_q)  # (B, T)
        return z_q, indices


class FiniteScalarQuantizer(nn.Module):
    """FSQ wrapped with linear projections into and out of the code dimension.

    `proj_in` maps `input_dim` -> len(levels) and `proj_out` maps
    len(levels) -> `output_dim`; each collapses to `nn.Identity` when the
    dimensions already match.
    """

    def __init__(self, input_dim: int, output_dim: int, levels: list[int]) -> None:
        super().__init__()
        self.input_dim_ = input_dim
        self.output_dim_ = output_dim

        self.fsq = FSQ(levels)
        logger.debug(
            f"Finite Scalar Quantizer with levels: {levels}, input_dim: {input_dim}, output_dim: {output_dim}, codebook_size: {self.all_codebook_size}"
        )

        # Only project when the dimensions differ; Identity keeps the path free.
        self.proj_in = nn.Linear(input_dim, len(levels)) if len(levels) != input_dim else nn.Identity()
        self.proj_out = nn.Linear(len(levels), output_dim) if len(levels) != output_dim else nn.Identity()

    def build_codebook(self) -> None:
        """No-op: FSQ's codebook is implicit in the level structure."""
        pass

    @property
    def output_dim(self) -> int:
        return self.output_dim_

    @property
    def all_codebook_size(self) -> int:
        """Total number of distinct codes: prod(levels)."""
        size = 1
        for level in self.fsq.levels:
            size *= level
        return size

    def forward(self, z: torch.Tensor) -> tuple[torch.Tensor, dict]:
        """Quantize z.

        Returns (z_q, info_dict) where info_dict carries the pre-quantization
        latent, the quantized latent (before proj_out), the flat codebook
        indices, and the batch perplexity.
        """
        latent = self.proj_in(z)  # Latent projected by proj_in
        quantized_latent, indices = self.fsq(latent)  # Quantized latent before proj_out
        z_q = self.proj_out(quantized_latent)

        # Perplexity (exp of entropy) of the empirical index distribution in
        # this batch — a codebook-usage diagnostic. Only the counts are
        # needed; the unique values themselves are discarded.
        flat_indices = indices.view(-1)
        _, counts = torch.unique(flat_indices, return_counts=True)
        used_indices_probs = counts.float() / flat_indices.numel()
        entropy = get_entropy(used_indices_probs)
        perplexity = torch.exp(entropy)

        info_dict = {
            "latent": latent,
            "quantized_latent": quantized_latent,
            "indices": indices,
            "perplexity": perplexity,
        }
        return z_q, info_dict

    def encode(self, z: torch.Tensor, skip_proj: bool = False) -> tuple[torch.Tensor, torch.Tensor]:
        """Quantize z to (z_q, indices); `skip_proj` leaves z_q in code space."""
        z = self.proj_in(z)
        z_q, indices = self.fsq.encode(z)
        if not skip_proj:
            z_q = self.proj_out(z_q)
        return z_q, indices

    def decode(self, indices: torch.Tensor) -> torch.Tensor:
        """Map flat indices back to output space via codes and proj_out."""
        z_q = self.fsq.decode(indices)
        z_q = self.proj_out(z_q)
        return z_q