# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import warnings
from collections import OrderedDict
import pytest
import torch
from tests.test_utils import assert_expected, assert_expected_namedtuple, set_rng_seed
from torch import nn, tensor
from torchmultimodal.modules.layers.codebook import Codebook
@pytest.fixture(autouse=True)
def random_seed():
    """Seed the RNG via ``set_rng_seed`` before every test (autouse).

    Expected values in tests that rely on random initialization are only
    reproducible with this fixed seed.
    """
    seed = 4
    set_rng_seed(seed)
@pytest.fixture
def num_embeddings():
    """Number of code vectors in the test codebook (rows of the table)."""
    n_codes = 4
    return n_codes
@pytest.fixture
def embedding_dim():
    """Length of each code vector (columns of the embedding table)."""
    dim = 5
    return dim
@pytest.fixture
def encoded():
    """Encoder output of shape (2, 5, 3) that tracks gradients.

    Dimension 1 (size 5) is the channel dimension that ``Codebook`` checks
    against ``embedding_dim``.
    """
    values = [
        [
            [-1.0, 0.0, 1.0],
            [2.0, 1.0, 0.0],
            [0.0, -1.0, -1.0],
            [0.0, 2.0, -1.0],
            [-2.0, -1.0, 1.0],
        ],
        [
            [2.0, 2.0, -1.0],
            [1.0, -1.0, -2.0],
            [0.0, 0.0, 0.0],
            [1.0, 2.0, 1.0],
            [1.0, 0.0, 0.0],
        ],
    ]
    # Leaf tensor with gradient tracking so backward() through the codebook works.
    return tensor(values, requires_grad=True)
@pytest.fixture
def embedding_weights():
    """Fixed 4x5 embedding table: one row per code, one column per channel."""
    rows = [
        [1.0, 0.0, -1.0, -1.0, 2.0],
        [2.0, -2.0, 0.0, 0.0, 1.0],
        [2.0, 1.0, 0.0, 1.0, 1.0],
        [-1.0, -2.0, 0.0, 2.0, 0.0],
    ]
    return tensor(rows)
@pytest.fixture
def input_tensor_flat():
    """A 4x3 tensor of ones, used as flattened input for ``_postprocess``."""
    return torch.ones(4, 3)
@pytest.fixture
def codebook(num_embeddings, embedding_dim):
    """Fresh ``Codebook`` sized by the ``num_embeddings``/``embedding_dim``
    fixtures, constructed with ``decay=0.3``."""
    decay = 0.3
    return Codebook(
        num_embeddings=num_embeddings,
        embedding_dim=embedding_dim,
        decay=decay,
    )
class TestCodebook:
    """Unit tests for ``Codebook``: quantization, pre/post-processing,
    embedding initialization, EMA updates, restarts, buffers and lookup."""

    def test_quantized_output(self, codebook, embedding_weights, encoded):
        # Install a fixed table and mark it initialized so forward() does not
        # re-initialize it from the encoder output.
        codebook.embedding = embedding_weights
        codebook._is_embedding_init = True
        actual = codebook(encoded)

        # This is shape (2,5,3)
        expected_quantized = tensor(
            [
                [
                    [2.0, 2.0, 1.0],
                    [1.0, 1.0, 0.0],
                    [0.0, 0.0, -1.0],
                    [1.0, 1.0, -1.0],
                    [1.0, 1.0, 2.0],
                ],
                [
                    [2.0, 2.0, -1.0],
                    [1.0, -2.0, -2.0],
                    [0.0, 0.0, 0.0],
                    [1.0, 0.0, 2.0],
                    [1.0, 1.0, 0.0],
                ],
            ]
        )
        # Move the channel dim last, then flatten to (batch * 3, embedding_dim).
        expected_quantized_flat = (
            expected_quantized.permute(0, 2, 1).contiguous().view(-1, 5)
        )
        expected = {
            "encoded_flat": encoded.permute(0, 2, 1).contiguous().view(-1, 5),
            "quantized_flat": expected_quantized_flat,
            "codebook_indices": tensor([[2, 2, 0], [2, 1, 3]], dtype=torch.long),
            "quantized": expected_quantized,
        }
        assert_expected_namedtuple(actual, expected)

    def test_preprocess(self, codebook, encoded):
        # (2, 5, 3) input: channel dim (5) moved last -> (2, 3, 5), flattened -> (6, 5).
        encoded_flat, permuted_shape = codebook._preprocess(encoded)

        expected_flat_shape = torch.tensor([6, 5])
        expected_permuted_shape = torch.tensor([2, 3, 5])

        actual_flat_shape = torch.tensor(encoded_flat.shape)
        actual_permuted_shape = torch.tensor(permuted_shape)

        assert_expected(actual_flat_shape, expected_flat_shape)
        assert_expected(actual_permuted_shape, expected_permuted_shape)

    def test_preprocess_channel_dim_assertion(self, codebook, encoded):
        # Only 4 channels instead of embedding_dim=5 must be rejected.
        with pytest.raises(ValueError):
            codebook._preprocess(encoded[:, :4, :])

    def test_postprocess(self, codebook, input_tensor_flat):
        # Flat (4, 3) input reshaped to (2, 2, 3) and channel dim moved back -> (2, 3, 2).
        quantized = codebook._postprocess(input_tensor_flat, torch.Size([2, 2, 3]))
        actual_quantized_shape = torch.tensor(quantized.shape)
        expected_quantized_shape = torch.tensor([2, 3, 2])

        assert_expected(actual_quantized_shape, expected_quantized_shape)

    def test_init_embedding(self, codebook, encoded, num_embeddings):
        assert (
            not codebook._is_embedding_init
        ), "embedding init flag not False initially"

        encoded_flat, _ = codebook._preprocess(encoded)
        codebook._init_embedding(encoded_flat)

        assert codebook._is_embedding_init, "embedding init flag not True after init"

        # Each expected row is one of the rows of encoded_flat; which rows are
        # picked depends on the fixed RNG seed.
        actual_weight = codebook.embedding
        expected_weight = tensor(
            [
                [2.0, -1.0, 0.0, 2.0, 0.0],
                [2.0, 1.0, 0.0, 1.0, 1.0],
                [0.0, 1.0, -1.0, 2.0, -1.0],
                [1.0, 0.0, -1.0, -1.0, 1.0],
            ]
        )
        assert_expected(actual_weight, expected_weight)

        # Running statistics start from the initial table with unit usage.
        actual_code_avg = codebook.code_avg
        expected_code_avg = actual_weight
        assert_expected(actual_code_avg, expected_code_avg)

        actual_code_usage = codebook.code_usage
        expected_code_usage = torch.ones(num_embeddings)
        assert_expected(actual_code_usage, expected_code_usage)

    def test_ema_update_embedding(self, codebook, encoded):
        encoded_flat, _ = codebook._preprocess(encoded)
        codebook._init_embedding(encoded_flat)
        # Nearest-code assignment by squared Euclidean distance.
        distances = torch.cdist(encoded_flat, codebook.embedding, p=2.0) ** 2
        codebook_indices = torch.argmin(distances, dim=1)
        codebook._ema_update_embedding(encoded_flat, codebook_indices)

        # Expected values depend on decay=0.3 and the fixed RNG seed.
        actual_weight = codebook.embedding
        expected_weight = tensor(
            [
                [0.7647, -1.4118, 0.0000, 1.5882, 0.0000],
                [2.0000, 1.0000, 0.0000, 1.0000, 1.0000],
                [-0.4118, 1.4118, -0.5882, 1.1765, -1.4118],
                [1.0000, 0.0000, -1.0000, -1.0000, 1.0000],
            ]
        )
        assert_expected(actual_weight, expected_weight, rtol=0.0, atol=1e-4)

        actual_code_avg = codebook.code_avg
        expected_code_avg = tensor(
            [
                [1.3000, -2.4000, 0.0000, 2.7000, 0.0000],
                [2.0000, 1.0000, 0.0000, 1.0000, 1.0000],
                [-0.7000, 2.4000, -1.0000, 2.0000, -2.4000],
                [1.0000, 0.0000, -1.0000, -1.0000, 1.0000],
            ]
        )
        assert_expected(actual_code_avg, expected_code_avg, rtol=0.0, atol=1e-4)

        actual_code_usage = codebook.code_usage
        expected_code_usage = tensor([1.7000, 1.0000, 1.7000, 1.0000])
        assert_expected(actual_code_usage, expected_code_usage, rtol=0.0, atol=1e-4)

    def test_register_buffer_tensors(self, codebook, encoded):
        out = codebook(encoded)
        out.quantized.sum().backward()

        msg_has_grad = "tensor assigned to buffer but accumulated grad"
        # Presumably silences PyTorch's warning about reading ``.grad`` of a
        # non-leaf tensor.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            # Compare with None explicitly: ``not tensor`` raises for
            # multi-element tensors and passes spuriously for a zero scalar.
            assert codebook.code_avg.grad is None, msg_has_grad
            assert codebook.code_usage.grad is None, msg_has_grad
            assert codebook.embedding.grad is None, msg_has_grad

        assert not list(
            codebook.parameters()
        ), "buffer variables incorrectly assigned as params"

    def test_init_embedding_smaller_encoded(self, codebook, encoded):
        # Only 2 encoder vectors for 4 codes, so some vectors must be reused.
        encoded_small = encoded[:1, :, :2]
        encoded_small_flat, _ = codebook._preprocess(encoded_small)
        codebook._init_embedding(encoded_small_flat)
        embed = codebook.embedding
        # Check for each embedding vector if there is one equal encoded vector + noise
        for emb in embed:
            assert any(
                torch.isclose(emb, enc, rtol=0, atol=0.01).all()
                for enc in encoded_small_flat
            ), "embedding initialized from encoder output incorrectly"

    def test_codebook_restart(self, codebook, encoded):
        encoded_flat, _ = codebook._preprocess(encoded)
        # First init and diversify embedding
        codebook._init_embedding(encoded_flat)
        # Use only embedding vector at index = 1 and force restarts.
        # Slightly modify encoded_flat to make sure vectors restart to something new
        encoded_flat_noise = encoded_flat + torch.randn_like(encoded_flat)
        codebook_indices_low_usage = torch.ones(encoded_flat.shape[0], dtype=torch.long)
        codebook._ema_update_embedding(encoded_flat_noise, codebook_indices_low_usage)

        # Check if embedding contains restarts
        for i, emb in enumerate(codebook.embedding):
            # We used only emb vector with index = 1, so check it was not restarted
            if i == 1:
                assert_expected(
                    emb,
                    codebook.code_avg[1] / codebook.code_usage[1],
                    rtol=0,
                    atol=1e-4,
                )
            # Compare each embedding vector to each encoded vector.
            # If at least one match, then restart happened.
            else:
                assert any(
                    torch.isclose(emb, enc, rtol=0, atol=1e-4).all()
                    for enc in encoded_flat_noise
                ), "embedding restarted from encoder output incorrectly"

    def test_load_state_dict(self):
        state_dict = OrderedDict(
            [
                ("linear.weight", tensor([[1.0]])),
                ("linear.bias", tensor([2.0])),
                ("codebook.embedding", tensor([[3.0]])),
                ("codebook.code_usage", tensor([4.0])),
                ("codebook.code_avg", tensor([[5.0]])),
            ]
        )

        class DummyModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = nn.Linear(1, 1)
                self.codebook = Codebook(1, 1)

        model = DummyModel()
        # Loading a state dict must mark the embedding as initialized.
        assert not model.codebook._is_embedding_init
        model.load_state_dict(state_dict)
        assert model.codebook._is_embedding_init

        actual = model.codebook.embedding
        expected = state_dict["codebook.embedding"]
        assert_expected(actual, expected)

        actual = model.codebook.code_usage
        expected = state_dict["codebook.code_usage"]
        assert_expected(actual, expected)

        actual = model.codebook.code_avg
        expected = state_dict["codebook.code_avg"]
        assert_expected(actual, expected)

    def test_lookup(self, codebook, embedding_weights):
        codebook.embedding = embedding_weights
        indices_flat = tensor([[0, 1]])  # (b, seq_len)
        indices_shaped = tensor([[[0, 1], [2, 3]]])  # (b, shape)
        actual_quantized_flat = codebook.lookup(indices_flat)
        actual_quantized = codebook.lookup(indices_shaped)
        # Each index is replaced by its row of the embedding table.
        expected_quantized_flat = tensor(
            [[[1.0, 0.0, -1.0, -1.0, 2.0], [2.0, -2.0, 0.0, 0.0, 1.0]]]
        )
        expected_quantized = tensor(
            [
                [
                    [[1.0, 0.0, -1.0, -1.0, 2.0], [2.0, -2.0, 0.0, 0.0, 1.0]],
                    [[2.0, 1.0, 0.0, 1.0, 1.0], [-1.0, -2.0, 0.0, 2.0, 0.0]],
                ]
            ]
        )
        assert_expected(
            actual_quantized_flat, expected_quantized_flat, rtol=0.0, atol=1e-4
        )
        assert_expected(actual_quantized, expected_quantized, rtol=0.0, atol=1e-4)
|