# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import warnings
from collections import OrderedDict
import pytest
import torch
from tests.test_utils import assert_expected, assert_expected_namedtuple, set_rng_seed
from torch import nn, tensor
from torchmultimodal.modules.layers.codebook import Codebook
@pytest.fixture(autouse=True)
def random_seed():
    """Reset the global RNG state before every test so tensor values are reproducible."""
    seed = 4
    set_rng_seed(seed)
@pytest.fixture
def num_embeddings():
    """Number of codebook vectors used throughout these tests."""
    count = 4
    return count
@pytest.fixture
def embedding_dim():
    """Dimensionality of each codebook vector used throughout these tests."""
    dim = 5
    return dim
@pytest.fixture
def encoded():
    """A 2x5x3 encoder-output tensor with gradients enabled."""
    values = [
        [
            [-1.0, 0.0, 1.0],
            [2.0, 1.0, 0.0],
            [0.0, -1.0, -1.0],
            [0.0, 2.0, -1.0],
            [-2.0, -1.0, 1.0],
        ],
        [
            [2.0, 2.0, -1.0],
            [1.0, -1.0, -2.0],
            [0.0, 0.0, 0.0],
            [1.0, 2.0, 1.0],
            [1.0, 0.0, 0.0],
        ],
    ]
    # requires_grad_() mutates in place and returns the same tensor.
    return tensor(values).requires_grad_()
@pytest.fixture
def embedding_weights():
    """A 4x5 matrix of codebook weights."""
    rows = [
        [1.0, 0.0, -1.0, -1.0, 2.0],
        [2.0, -2.0, 0.0, 0.0, 1.0],
        [2.0, 1.0, 0.0, 1.0, 1.0],
        [-1.0, -2.0, 0.0, 2.0, 0.0],
    ]
    return tensor(rows)
@pytest.fixture
def input_tensor_flat():
    """A 4x3 all-ones float tensor (same values/dtype as the explicit literal)."""
    return torch.ones(4, 3)
@pytest.fixture
def codebook(num_embeddings, embedding_dim):
    """A Codebook under test, configured with an EMA decay of 0.3."""
    config = {
        "num_embeddings": num_embeddings,
        "embedding_dim": embedding_dim,
        "decay": 0.3,
    }
    return Codebook(**config)
class TestCodebook:
    """Unit tests for the Codebook vector-quantization layer.

    Covers the forward quantization path, pre/post-processing reshapes,
    data-dependent embedding initialization, EMA updates with codebook
    restarts, buffer registration, state-dict loading, and index lookup.
    """

    def test_quantized_output(self, codebook, embedding_weights, encoded):
        """Forward pass snaps each encoded vector to a codebook row."""
        # Inject known weights and mark the codebook as initialized so the
        # forward pass skips data-dependent initialization.
        codebook.embedding = embedding_weights
        codebook._is_embedding_init = True
        actual = codebook(encoded)
        # This is shape (2,5,3)
        expected_quantized = tensor(
            [
                [
                    [2.0, 2.0, 1.0],
                    [1.0, 1.0, 0.0],
                    [0.0, 0.0, -1.0],
                    [1.0, 1.0, -1.0],
                    [1.0, 1.0, 2.0],
                ],
                [
                    [2.0, 2.0, -1.0],
                    [1.0, -2.0, -2.0],
                    [0.0, 0.0, 0.0],
                    [1.0, 0.0, 2.0],
                    [1.0, 1.0, 0.0],
                ],
            ]
        )
        # Flattened view: channel dim moved last, then collapsed to (b*seq, dim).
        expected_quantized_flat = (
            expected_quantized.permute(0, 2, 1).contiguous().view(-1, 5)
        )
        expected = {
            "encoded_flat": encoded.permute(0, 2, 1).contiguous().view(-1, 5),
            "quantized_flat": expected_quantized_flat,
            "codebook_indices": tensor([[2.0, 2.0, 0.0], [2.0, 1.0, 3.0]]).type(
                torch.LongTensor
            ),
            "quantized": expected_quantized,
        }
        assert_expected_namedtuple(actual, expected)

    def test_preprocess(self, codebook, encoded):
        """_preprocess flattens (b, c, seq) input to (b*seq, c) and returns the permuted shape."""
        encoded_flat, permuted_shape = codebook._preprocess(encoded)
        expected_flat_shape = torch.tensor([6, 5])
        expected_permuted_shape = torch.tensor([2, 3, 5])
        actual_flat_shape = torch.tensor(encoded_flat.shape)
        actual_permuted_shape = torch.tensor(permuted_shape)
        assert_expected(actual_flat_shape, expected_flat_shape)
        assert_expected(actual_permuted_shape, expected_permuted_shape)

    def test_preprocess_channel_dim_assertion(self, codebook, encoded):
        """_preprocess raises when the channel dim does not match embedding_dim."""
        with pytest.raises(ValueError):
            # Slicing the channel dim from 5 to 4 mismatches embedding_dim=5.
            codebook._preprocess(encoded[:, :4, :])

    def test_postprocess(self, codebook, input_tensor_flat):
        """_postprocess reshapes the flat tensor back using the saved permuted shape."""
        quantized = codebook._postprocess(input_tensor_flat, torch.Size([2, 2, 3]))
        actual_quantized_shape = torch.tensor(quantized.shape)
        expected_quantized_shape = torch.tensor([2, 3, 2])
        assert_expected(actual_quantized_shape, expected_quantized_shape)

    def test_init_embedding(self, codebook, encoded, num_embeddings):
        """_init_embedding seeds weights from encoder output and sets bookkeeping state."""
        assert (
            not codebook._is_embedding_init
        ), "embedding init flag not False initially"
        encoded_flat, _ = codebook._preprocess(encoded)
        codebook._init_embedding(encoded_flat)
        assert codebook._is_embedding_init, "embedding init flag not True after init"
        # Weights should come from the (seeded) encoder-output sample.
        actual_weight = codebook.embedding
        expected_weight = tensor(
            [
                [2.0, -1.0, 0.0, 2.0, 0.0],
                [2.0, 1.0, 0.0, 1.0, 1.0],
                [0.0, 1.0, -1.0, 2.0, -1.0],
                [1.0, 0.0, -1.0, -1.0, 1.0],
            ]
        )
        assert_expected(actual_weight, expected_weight)
        # code_avg starts out identical to the embedding itself.
        actual_code_avg = codebook.code_avg
        expected_code_avg = actual_weight
        assert_expected(actual_code_avg, expected_code_avg)
        # Every code starts with a usage count of one.
        actual_code_usage = codebook.code_usage
        expected_code_usage = torch.ones(num_embeddings)
        assert_expected(actual_code_usage, expected_code_usage)

    def test_ema_update_embedding(self, codebook, encoded):
        """_ema_update_embedding applies one EMA step (decay=0.3) to weights and stats."""
        encoded_flat, _ = codebook._preprocess(encoded)
        codebook._init_embedding(encoded_flat)
        # Assign each flattened vector to its nearest code by squared L2 distance.
        distances = torch.cdist(encoded_flat, codebook.embedding, p=2.0) ** 2
        codebook_indices = torch.argmin(distances, dim=1)
        codebook._ema_update_embedding(encoded_flat, codebook_indices)
        actual_weight = codebook.embedding
        expected_weight = tensor(
            [
                [0.7647, -1.4118, 0.0000, 1.5882, 0.0000],
                [2.0000, 1.0000, 0.0000, 1.0000, 1.0000],
                [-0.4118, 1.4118, -0.5882, 1.1765, -1.4118],
                [1.0000, 0.0000, -1.0000, -1.0000, 1.0000],
            ]
        )
        assert_expected(actual_weight, expected_weight, rtol=0.0, atol=1e-4)
        actual_code_avg = codebook.code_avg
        expected_code_avg = tensor(
            [
                [1.3000, -2.4000, 0.0000, 2.7000, 0.0000],
                [2.0000, 1.0000, 0.0000, 1.0000, 1.0000],
                [-0.7000, 2.4000, -1.0000, 2.0000, -2.4000],
                [1.0000, 0.0000, -1.0000, -1.0000, 1.0000],
            ]
        )
        assert_expected(actual_code_avg, expected_code_avg, rtol=0.0, atol=1e-4)
        actual_code_usage = codebook.code_usage
        expected_code_usage = tensor([1.7000, 1.0000, 1.7000, 1.0000])
        assert_expected(actual_code_usage, expected_code_usage, rtol=0.0, atol=1e-4)

    def test_register_buffer_tensors(self, codebook, encoded):
        """Codebook state tensors are buffers: no grads accumulated, no params exposed."""
        out = codebook(encoded)
        out.quantized.sum().backward()
        msg_has_grad = "tensor assigned to buffer but accumulated grad"
        # Accessing .grad on non-leaf tensors can emit a UserWarning; silence it.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            assert not codebook.code_avg.grad, msg_has_grad
            assert not codebook.code_usage.grad, msg_has_grad
            assert not codebook.embedding.grad, msg_has_grad
        assert not list(
            codebook.parameters()
        ), "buffer variables incorrectly assigned as params"

    def test_init_embedding_smaller_encoded(self, codebook, encoded):
        """Init with fewer encoded vectors than codes still draws codes from the data."""
        encoded_small = encoded[:1, :, :2]
        encoded_small_flat, _ = codebook._preprocess(encoded_small)
        codebook._init_embedding(encoded_small_flat)
        embed = codebook.embedding
        # Check for each embedding vector if there is one equal encoded vector + noise
        for emb in embed:
            assert any(
                [
                    torch.isclose(emb, enc, rtol=0, atol=0.01).all()
                    for enc in encoded_small_flat
                ]
            ), "embedding restarted from encoder output incorrectly" if False else "embedding initialized from encoder output incorrectly"

    def test_codebook_restart(self, codebook, encoded):
        """Codes with low usage get restarted from fresh encoder-output vectors."""
        encoded_flat, _ = codebook._preprocess(encoded)
        # First init and diversify embedding
        codebook._init_embedding(encoded_flat)
        # Use only embedding vector at index = 1 and force restarts.
        # Slightly modify encoded_flat to make sure vectors restart to something new
        encoded_flat_noise = encoded_flat + torch.randn_like(encoded_flat)
        codebook_indices_low_usage = torch.ones(encoded_flat.shape[0], dtype=torch.long)
        codebook._ema_update_embedding(encoded_flat_noise, codebook_indices_low_usage)
        # Check if embedding contains restarts
        for i, emb in enumerate(codebook.embedding):
            # We used only emb vector with index = 1, so check it was not restarted
            if i == 1:
                assert_expected(
                    emb,
                    codebook.code_avg[1] / codebook.code_usage[1],
                    rtol=0,
                    atol=1e-4,
                )
            # Compare each embedding vector to each encoded vector.
            # If at least one match, then restart happened.
            else:
                assert any(
                    [
                        torch.isclose(emb, enc, rtol=0, atol=1e-4).all()
                        for enc in encoded_flat_noise
                    ]
                ), "embedding restarted from encoder output incorrectly"

    def test_load_state_dict(self):
        """Loading a state dict restores the buffers and flips the init flag."""
        state_dict = OrderedDict(
            [
                ("linear.weight", tensor([[1.0]])),
                ("linear.bias", tensor([2.0])),
                ("codebook.embedding", tensor([[3.0]])),
                ("codebook.code_usage", tensor([4.0])),
                ("codebook.code_avg", tensor([[5.0]])),
            ]
        )

        # Minimal module wrapping a Codebook so load_state_dict routes the
        # "codebook.*" keys into the buffers.
        class DummyModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = nn.Linear(1, 1)
                self.codebook = Codebook(1, 1)

        model = DummyModel()
        assert not model.codebook._is_embedding_init
        model.load_state_dict(state_dict)
        # Loading pretrained buffers should mark the embedding as initialized.
        assert model.codebook._is_embedding_init
        actual = model.codebook.embedding
        expected = state_dict["codebook.embedding"]
        assert_expected(actual, expected)
        actual = model.codebook.code_usage
        expected = state_dict["codebook.code_usage"]
        assert_expected(actual, expected)
        actual = model.codebook.code_avg
        expected = state_dict["codebook.code_avg"]
        assert_expected(actual, expected)

    def test_lookup(self, codebook, embedding_weights):
        """lookup maps integer indices (flat or shaped) to their embedding vectors."""
        codebook.embedding = embedding_weights
        indices_flat = tensor([[0, 1]])  # (b, seq_len)
        indices_shaped = tensor([[[0, 1], [2, 3]]])  # (b, shape)
        actual_quantized_flat = codebook.lookup(indices_flat)
        actual_quantized = codebook.lookup(indices_shaped)
        expected_quantized_flat = tensor(
            [[[1.0, 0.0, -1.0, -1.0, 2.0], [2.0, -2.0, 0.0, 0.0, 1.0]]]
        )
        expected_quantized = tensor(
            [
                [
                    [[1.0, 0.0, -1.0, -1.0, 2.0], [2.0, -2.0, 0.0, 0.0, 1.0]],
                    [[2.0, 1.0, 0.0, 1.0, 1.0], [-1.0, -2.0, 0.0, 2.0, 0.0]],
                ]
            ]
        )
        assert_expected(
            actual_quantized_flat, expected_quantized_flat, rtol=0.0, atol=1e-4
        )
        assert_expected(actual_quantized, expected_quantized, rtol=0.0, atol=1e-4)