File size: 11,334 Bytes
f233443
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import warnings
from collections import OrderedDict

import pytest

import torch
from tests.test_utils import assert_expected, assert_expected_namedtuple, set_rng_seed
from torch import nn, tensor
from torchmultimodal.modules.layers.codebook import Codebook


@pytest.fixture(autouse=True)
def random_seed():
    """Call ``set_rng_seed`` with a fixed seed before each test (autouse)."""
    fixed_seed = 4
    set_rng_seed(fixed_seed)


@pytest.fixture
def num_embeddings():
    """Value passed as ``num_embeddings`` when building the ``Codebook`` fixture."""
    n_codes = 4
    return n_codes


@pytest.fixture
def embedding_dim():
    """Value passed as ``embedding_dim`` when building the ``Codebook`` fixture."""
    code_width = 5
    return code_width


@pytest.fixture
def encoded():
    """Fixed 2x5x3 float tensor with ``requires_grad`` enabled.

    Dim 1 has size 5, matching the ``embedding_dim`` fixture.
    """
    first_sample = [
        [-1.0, 0.0, 1.0],
        [2.0, 1.0, 0.0],
        [0.0, -1.0, -1.0],
        [0.0, 2.0, -1.0],
        [-2.0, -1.0, 1.0],
    ]
    second_sample = [
        [2.0, 2.0, -1.0],
        [1.0, -1.0, -2.0],
        [0.0, 0.0, 0.0],
        [1.0, 2.0, 1.0],
        [1.0, 0.0, 0.0],
    ]
    # Stacking the two 5x3 samples gives the 2x5x3 input.
    latents = tensor([first_sample, second_sample])
    # ``requires_grad_`` works in place and returns the same tensor.
    return latents.requires_grad_()


@pytest.fixture
def embedding_weights():
    """Fixed 4x5 tensor that tests assign to ``codebook.embedding``."""
    rows = [
        [1.0, 0.0, -1.0, -1.0, 2.0],
        [2.0, -2.0, 0.0, 0.0, 1.0],
        [2.0, 1.0, 0.0, 1.0, 1.0],
        [-1.0, -2.0, 0.0, 2.0, 0.0],
    ]
    return tensor(rows)


@pytest.fixture
def input_tensor_flat():
    """All-ones float tensor of shape 4x3 used as flat input to ``_postprocess``."""
    n_rows, n_cols = 4, 3
    return torch.ones(n_rows, n_cols)


@pytest.fixture
def codebook(num_embeddings, embedding_dim):
    """Build a ``Codebook`` from the size fixtures with ``decay=0.3``."""
    module = Codebook(
        decay=0.3,
        embedding_dim=embedding_dim,
        num_embeddings=num_embeddings,
    )
    return module


class TestCodebook:
    """Tests for ``Codebook``.

    Covers the forward pass, the ``_preprocess``/``_postprocess`` shape
    helpers, embedding initialization, the EMA update, restarts, loading
    from a state dict, and ``lookup``.
    """

    def test_quantized_output(self, codebook, embedding_weights, encoded):
        """Forward pass with fixed embedding weights returns the expected fields."""
        codebook.embedding = embedding_weights
        # NOTE(review): presumably this keeps the forward pass from
        # re-initializing the embedding from ``encoded`` - confirm in Codebook.
        codebook._is_embedding_init = True
        actual = codebook(encoded)

        # This is shape (2,5,3)
        expected_quantized = tensor(
            [
                [
                    [2.0, 2.0, 1.0],
                    [1.0, 1.0, 0.0],
                    [0.0, 0.0, -1.0],
                    [1.0, 1.0, -1.0],
                    [1.0, 1.0, 2.0],
                ],
                [
                    [2.0, 2.0, -1.0],
                    [1.0, -2.0, -2.0],
                    [0.0, 0.0, 0.0],
                    [1.0, 0.0, 2.0],
                    [1.0, 1.0, 0.0],
                ],
            ]
        )
        expected_quantized_flat = (
            expected_quantized.permute(0, 2, 1).contiguous().view(-1, 5)
        )

        expected = {
            "encoded_flat": encoded.permute(0, 2, 1).contiguous().view(-1, 5),
            "quantized_flat": expected_quantized_flat,
            # Integer literals already produce an int64 tensor.
            "codebook_indices": tensor([[2, 2, 0], [2, 1, 3]]),
            "quantized": expected_quantized,
        }

        assert_expected_namedtuple(actual, expected)

    def test_preprocess(self, codebook, encoded):
        """``_preprocess`` turns the (2, 5, 3) input into a (6, 5) flat tensor
        and reports a permuted shape of (2, 3, 5)."""
        encoded_flat, permuted_shape = codebook._preprocess(encoded)

        expected_flat_shape = torch.tensor([6, 5])
        expected_permuted_shape = torch.tensor([2, 3, 5])

        actual_flat_shape = torch.tensor(encoded_flat.shape)
        actual_permuted_shape = torch.tensor(permuted_shape)

        assert_expected(actual_flat_shape, expected_flat_shape)

        assert_expected(actual_permuted_shape, expected_permuted_shape)

    def test_preprocess_channel_dim_assertion(self, codebook, encoded):
        """``_preprocess`` raises ``ValueError`` when dim 1 has 4 entries
        instead of the 5 the codebook was built with."""
        with pytest.raises(ValueError):
            codebook._preprocess(encoded[:, :4, :])

    def test_postprocess(self, codebook, input_tensor_flat):
        """``_postprocess`` maps a (4, 3) flat tensor with permuted shape
        (2, 2, 3) to an output of shape (2, 3, 2)."""
        quantized = codebook._postprocess(input_tensor_flat, torch.Size([2, 2, 3]))
        actual_quantized_shape = torch.tensor(quantized.shape)
        expected_quantized_shape = torch.tensor([2, 3, 2])

        assert_expected(actual_quantized_shape, expected_quantized_shape)

    def test_init_embedding(self, codebook, encoded, num_embeddings):
        """``_init_embedding`` sets the init flag and fills ``embedding``,
        ``code_avg`` and ``code_usage`` with the expected values."""
        assert (
            not codebook._is_embedding_init
        ), "embedding init flag not False initially"

        encoded_flat, _ = codebook._preprocess(encoded)
        codebook._init_embedding(encoded_flat)

        assert codebook._is_embedding_init, "embedding init flag not True after init"

        actual_weight = codebook.embedding
        expected_weight = tensor(
            [
                [2.0, -1.0, 0.0, 2.0, 0.0],
                [2.0, 1.0, 0.0, 1.0, 1.0],
                [0.0, 1.0, -1.0, 2.0, -1.0],
                [1.0, 0.0, -1.0, -1.0, 1.0],
            ]
        )
        assert_expected(actual_weight, expected_weight)

        # Right after init, the running average is expected to equal the embedding.
        actual_code_avg = codebook.code_avg
        expected_code_avg = actual_weight
        assert_expected(actual_code_avg, expected_code_avg)

        actual_code_usage = codebook.code_usage
        expected_code_usage = torch.ones(num_embeddings)
        assert_expected(actual_code_usage, expected_code_usage)

    def test_ema_update_embedding(self, codebook, encoded):
        """One ``_ema_update_embedding`` call after init yields the expected
        ``embedding``, ``code_avg`` and ``code_usage``."""
        encoded_flat, _ = codebook._preprocess(encoded)
        codebook._init_embedding(encoded_flat)
        # Assign each encoded vector to its nearest embedding vector
        # (squared Euclidean distance).
        distances = torch.cdist(encoded_flat, codebook.embedding, p=2.0) ** 2
        codebook_indices = torch.argmin(distances, dim=1)
        codebook._ema_update_embedding(encoded_flat, codebook_indices)

        actual_weight = codebook.embedding
        expected_weight = tensor(
            [
                [0.7647, -1.4118, 0.0000, 1.5882, 0.0000],
                [2.0000, 1.0000, 0.0000, 1.0000, 1.0000],
                [-0.4118, 1.4118, -0.5882, 1.1765, -1.4118],
                [1.0000, 0.0000, -1.0000, -1.0000, 1.0000],
            ]
        )
        assert_expected(actual_weight, expected_weight, rtol=0.0, atol=1e-4)

        actual_code_avg = codebook.code_avg
        expected_code_avg = tensor(
            [
                [1.3000, -2.4000, 0.0000, 2.7000, 0.0000],
                [2.0000, 1.0000, 0.0000, 1.0000, 1.0000],
                [-0.7000, 2.4000, -1.0000, 2.0000, -2.4000],
                [1.0000, 0.0000, -1.0000, -1.0000, 1.0000],
            ]
        )
        assert_expected(actual_code_avg, expected_code_avg, rtol=0.0, atol=1e-4)

        actual_code_usage = codebook.code_usage
        expected_code_usage = tensor([1.7000, 1.0000, 1.7000, 1.0000])
        assert_expected(actual_code_usage, expected_code_usage, rtol=0.0, atol=1e-4)

    def test_register_buffer_tensors(self, codebook, encoded):
        """After a backward pass, ``code_avg``, ``code_usage`` and ``embedding``
        have no gradient and the module exposes no parameters."""
        out = codebook(encoded)
        out.quantized.sum().backward()

        msg_has_grad = "tensor assigned to buffer but accumulated grad"
        # Reading ``.grad`` may emit a warning; silence it for these checks.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            # Compare with ``is None`` rather than truthiness: a multi-element
            # grad tensor has no unambiguous boolean value, so ``not grad``
            # would raise instead of failing with ``msg_has_grad``.
            assert codebook.code_avg.grad is None, msg_has_grad
            assert codebook.code_usage.grad is None, msg_has_grad
            assert codebook.embedding.grad is None, msg_has_grad

        assert not list(
            codebook.parameters()
        ), "buffer variables incorrectly assigned as params"

    def test_init_embedding_smaller_encoded(self, codebook, encoded):
        """``_init_embedding`` with fewer encoded vectors than embeddings still
        gives every embedding vector a close match among the encoded vectors."""
        encoded_small = encoded[:1, :, :2]
        encoded_small_flat, _ = codebook._preprocess(encoded_small)
        codebook._init_embedding(encoded_small_flat)
        embed = codebook.embedding
        # Check for each embedding vector if there is one equal encoded vector + noise
        for emb in embed:
            assert any(
                torch.isclose(emb, enc, rtol=0, atol=0.01).all()
                for enc in encoded_small_flat
            ), "embedding initialized from encoder output incorrectly"

    def test_codebook_restart(self, codebook, encoded):
        """An EMA update that assigns every vector to index 1 leaves index 1 at
        its EMA estimate and makes the other vectors match encoded vectors."""
        encoded_flat, _ = codebook._preprocess(encoded)
        # First init and diversify embedding
        codebook._init_embedding(encoded_flat)
        # Use only embedding vector at index = 1 and force restarts.
        # Slightly modify encoded_flat to make sure vectors restart to something new
        encoded_flat_noise = encoded_flat + torch.randn_like(encoded_flat)
        codebook_indices_low_usage = torch.ones(encoded_flat.shape[0], dtype=torch.long)
        codebook._ema_update_embedding(encoded_flat_noise, codebook_indices_low_usage)

        # Check if embedding contains restarts
        for i, emb in enumerate(codebook.embedding):
            # We used only emb vector with index = 1, so check it was not restarted
            if i == 1:
                assert_expected(
                    emb,
                    codebook.code_avg[1] / codebook.code_usage[1],
                    rtol=0,
                    atol=1e-4,
                )
            # Compare each embedding vector to each encoded vector.
            # If at least one match, then restart happened.
            else:
                assert any(
                    torch.isclose(emb, enc, rtol=0, atol=1e-4).all()
                    for enc in encoded_flat_noise
                ), "embedding restarted from encoder output incorrectly"

    def test_load_state_dict(self):
        """Loading a state dict into a model containing a ``Codebook`` sets the
        init flag and restores ``embedding``, ``code_usage`` and ``code_avg``."""
        state_dict = OrderedDict(
            [
                ("linear.weight", tensor([[1.0]])),
                ("linear.bias", tensor([2.0])),
                ("codebook.embedding", tensor([[3.0]])),
                ("codebook.code_usage", tensor([4.0])),
                ("codebook.code_avg", tensor([[5.0]])),
            ]
        )

        class DummyModel(nn.Module):
            """Minimal parent module holding a linear layer and a codebook."""

            def __init__(self):
                super().__init__()
                self.linear = nn.Linear(1, 1)
                self.codebook = Codebook(1, 1)

        model = DummyModel()
        assert not model.codebook._is_embedding_init
        model.load_state_dict(state_dict)
        assert model.codebook._is_embedding_init

        actual = model.codebook.embedding
        expected = state_dict["codebook.embedding"]
        assert_expected(actual, expected)

        actual = model.codebook.code_usage
        expected = state_dict["codebook.code_usage"]
        assert_expected(actual, expected)

        actual = model.codebook.code_avg
        expected = state_dict["codebook.code_avg"]
        assert_expected(actual, expected)

    def test_lookup(self, codebook, embedding_weights):
        """``lookup`` returns the embedding rows selected by the index tensor,
        for both a (1, 2) and a (1, 2, 2) index tensor."""
        codebook.embedding = embedding_weights
        indices_flat = tensor([[0, 1]])  # (b, seq_len)
        indices_shaped = tensor([[[0, 1], [2, 3]]])  # (b, shape)
        actual_quantized_flat = codebook.lookup(indices_flat)
        actual_quantized = codebook.lookup(indices_shaped)
        expected_quantized_flat = tensor(
            [[[1.0, 0.0, -1.0, -1.0, 2.0], [2.0, -2.0, 0.0, 0.0, 1.0]]]
        )
        expected_quantized = tensor(
            [
                [
                    [[1.0, 0.0, -1.0, -1.0, 2.0], [2.0, -2.0, 0.0, 0.0, 1.0]],
                    [[2.0, 1.0, 0.0, 1.0, 1.0], [-1.0, -2.0, 0.0, 2.0, 0.0]],
                ]
            ]
        )
        assert_expected(
            actual_quantized_flat, expected_quantized_flat, rtol=0.0, atol=1e-4
        )
        assert_expected(actual_quantized, expected_quantized, rtol=0.0, atol=1e-4)