File size: 4,400 Bytes
cd8454d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
# Copyright (C) 2025. Huawei Technologies Co., Ltd. All Rights Reserved. (authors: Dehua Tao)
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import torch
import torch.nn as nn
import logging

# Annotation alias for tensors holding quantized codes.
Codeword = torch.FloatTensor
# Annotation alias for tensors holding codebook indices. These are float-valued
# in this file: `FSQ.codes_to_indexes` sums float products without casting the
# result to an integer dtype.
Indices = torch.FloatTensor


def round_ste(z):
    """Round `z` to the nearest integer with a straight-through gradient.

    The rounding residual is detached from the autograd graph, so the forward
    value is (up to floating-point error) `torch.round(z)` while the backward
    pass treats the operation as the identity.
    """
    rounding_residual = (torch.round(z) - z).detach()
    return z + rounding_residual


class FSQ(nn.Module):
    """Quantizer."""

    def __init__(self, levels: list, eps: float = 1e-3, l2_norm: bool = False, batch_norm: bool = False):
        super().__init__()

        self._levels = levels
        self._eps = eps
        self.l2_norm = l2_norm
        self.batch_norm = batch_norm
        self._levels_np = torch.Tensor(levels)
        self._basis = torch.cat((torch.Tensor([1]), torch.cumprod(self._levels_np[:-1], dim=0)))
        self._implicit_codebook = self.indexes_to_codes(torch.arange(self.codebook_size))
        logging.info(f'levels: {levels}')

        if self.batch_norm:
            self.bn = nn.BatchNorm1d(self.num_dimensions, momentum=0.01, eps=1e-3)

    @property
    def num_dimensions(self) -> int:
        """Number of dimensions expected from inputs."""
        return len(self._levels)

    @property
    def codebook_size(self) -> int:
        """Size of the codebook."""
        return np.prod(self._levels)

    @property
    def codebook(self):
        """Returns the implicit codebook. Shape (prod(levels), num_dimensions)."""
        return self._implicit_codebook

    def bound(self, z: torch.FloatTensor) -> torch.FloatTensor:
        """Bound `z`, an array of shape (..., d)."""
        half_l = (self._levels_np - 1) * (1 - self._eps) / 2
        offset = torch.where(self._levels_np % 2 == 1, 0.0, 0.5)
        shift = torch.tan(offset / half_l)
        return torch.tanh(z + shift) * half_l - offset

    def quantize(self, z: torch.FloatTensor) -> Codeword:
        """Quanitzes z, returns quantized zhat, same shape as z."""
        quantized = round_ste(self.bound(z))

        # Renormalize to [-1, 1].
        half_width = torch.div(self._levels_np, 2, rounding_mode='floor')
        return quantized / half_width

    def _scale_and_shift(self, zhat_normalized):
        # Scale and shift to range [0, ..., L-1]
        half_width = torch.div(self._levels_np, 2, rounding_mode='floor')
        return (zhat_normalized * half_width) + half_width

    def _scale_and_shift_inverse(self, zhat):
        # Note that array(x) // 2 != tensor(x) // 2 when x is negative
        half_width = torch.div(self._levels_np, 2, rounding_mode='floor')
        return (zhat - half_width) / half_width

    def codes_to_indexes(self, zhat: Codeword) -> Indices:
        """Converts a `code` to an index in the codebook."""
        zhat = self._scale_and_shift(zhat)
        return torch.sum(zhat * self._basis, axis=-1)

    def indexes_to_codes(self, indices: Indices) -> Codeword:
        """Inverse of `indexes_to_codes`."""
        indices = indices.unsqueeze(-1)
        codes_non_centered = torch.remainder(
            torch.div(indices, self._basis, rounding_mode='floor'), self._levels_np
        )
        return self._scale_and_shift_inverse(codes_non_centered)

    def forward(self, z: torch.FloatTensor) -> Codeword:
        # Specify cuda index
        cuda_index = z.get_device()
        self._levels_np = self._levels_np.to(f'cuda:{cuda_index}')
        self._basis = self._basis.to(f'cuda:{cuda_index}')
        self._implicit_codebook = self._implicit_codebook.to(f'cuda:{cuda_index}')

        if self.l2_norm:
            z = nn.functional.normalize(z, p=2, dim=-1)

        if self.batch_norm:
            self.bn = self.bn.to(f'cuda:{cuda_index}')

            z = z.permute(0, 2, 1)
            z = self.bn(z)
            z = z.permute(0, 2, 1)

        zhat = self.quantize(z)

        return zhat