import pytest

import torch
from tests.test_utils import assert_expected
from torch import nn
from torchmultimodal.modules.layers.position_embedding import (
    BroadcastedPositionEmbedding,
    SinusoidalPositionEmbeddings,
)


class TestBroadcastedPositionEmbedding:
    @pytest.fixture(scope="class")
    def pos_emb(self):
        _pos_emb = BroadcastedPositionEmbedding(
            latent_shape=(1, 2, 3),
            embedding_dim=6,
        )
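        # Replace the randomly initialized per-dim embeddings with fixed
        # values so the broadcast/forward expectations below are
        # deterministic: one 2-dim row per position along each latent dim,
        # concatenating to embedding_dim = 6 across the 3 dims.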
        _pos_emb.embedding = nn.ParameterDict(
            {
                "d_0": nn.Parameter(torch.tensor([[0.0, 1.0]])),
                "d_1": nn.Parameter(torch.tensor([[2.0, 3.0], [4.0, 5.0]])),
                "d_2": nn.Parameter(torch.tensor([[6.0, 7.0], [8.0, 9.0], [0.0, 1.0]])),
            }
        )

        return _pos_emb

    def test_init_sets_embedding(self, pos_emb):
        """Test that the per-dim embeddings are initialized with the correct shapes"""
        expected = [(1, 2), (2, 2), (3, 2)]
        for i, (key, _) in enumerate(pos_emb.embedding.items()):
            assert_expected(pos_emb.embedding[key].shape, torch.Size(expected[i]))

    def test_init_bad_embedding_dim(self):
        """Test raising an error when the embedding dim is not divisible
        by the number of latent dims"""
        with pytest.raises(ValueError):
            BroadcastedPositionEmbedding(latent_shape=(1, 2, 3), embedding_dim=5)

    def test_broadcast(self, pos_emb):
        """Test that the embedding along each dim is broadcast correctly"""
        expected = [
            torch.tensor(
                [
                    [
                        [[0.0, 1.0], [0.0, 1.0], [0.0, 1.0]],
                        [[0.0, 1.0], [0.0, 1.0], [0.0, 1.0]],
                    ],
                ]
            ),
            torch.tensor(
                [
                    [
                        [[2.0, 3.0], [2.0, 3.0], [2.0, 3.0]],
                        [[4.0, 5.0], [4.0, 5.0], [4.0, 5.0]],
                    ],
                ]
            ),
            torch.tensor(
                [
                    [
                        [[6.0, 7.0], [8.0, 9.0], [0.0, 1.0]],
                        [[6.0, 7.0], [8.0, 9.0], [0.0, 1.0]],
                    ],
                ]
            ),
        ]
        for i in range(pos_emb.n_dim):
            assert_expected(pos_emb._broadcast(i), expected[i])

    def test_forward(self, pos_emb):
        """Test that the correct embeddings are returned for the given position ids"""
        position_ids = torch.tensor([[1, 3, -1]])
        actual = pos_emb(position_ids)
        expected = torch.tensor(
            [
                [
                    [0.0, 1.0, 2.0, 3.0, 8.0, 9.0],
                    [0.0, 1.0, 4.0, 5.0, 6.0, 7.0],
                    [0.0, 1.0, 4.0, 5.0, 0.0, 1.0],
                ]
            ]
        )
        assert_expected(actual, expected)

    def test_forward_invalid_input(self, pos_emb):
        """Test that an error is raised when position ids contain invalid values"""
        with pytest.raises(IndexError) as exc_info:
            pos_emb(position_ids=torch.tensor([[-2, 0]]))
        assert exc_info.value.args[0] == "Invalid position ids: tensor([-2])"
        with pytest.raises(IndexError) as exc_info:
            pos_emb(position_ids=torch.tensor([[0, 6]]))
        assert exc_info.value.args[0] == "Invalid position ids: tensor([6])"


class TestSinusoidalPositionEmbeddings:
    @pytest.fixture
    def data(self):
        return torch.tensor([1.0, 2.0, 3.0])

    @pytest.fixture
    def emb(self):
        return SinusoidalPositionEmbeddings(5)

    def test_forward(self, data, emb):
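        # Only the output shape is checked: 3 input positions, each mapped
        # to a 5-dim sinusoidal embedding.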
        actual = emb(data)
        expected = torch.Size([3, 5])
        assert_expected(actual.shape, expected)