import pytest

import torch
from tests.test_utils import assert_expected, set_rng_seed
from torch import nn
from torchmultimodal.modules.layers.patch_embedding import PatchEmbeddings
@pytest.fixture(autouse=True)
def random():
    """Fix the global RNG seed before every test so results are deterministic."""
    set_rng_seed(0)
@pytest.fixture
def inputs():
    """All-ones input batch: (batch=2, channels=3, height=2, width=2)."""
    return torch.full((2, 3, 2, 2), 1.0)
@pytest.fixture
def mask():
    """0/1 patch mask for 2 samples x 4 patches."""
    rows = [
        [1, 1, 0, 1],
        [0, 1, 1, 0],
    ]
    return torch.tensor(rows)
class TestPatchEmbeddings:
    """Tests for PatchEmbeddings: plain forward, image masking, patch
    dropping, rectangular images, and omitting the CLS token."""

    def _init_conv_proj(self, model):
        # Deterministic 1x1 conv weights so expected outputs are easy to
        # compute by hand: out ch 0 = 0*r + 1*g + 2*b, out ch 1 = 3*r + 4*g + 5*b.
        weight = torch.tensor(
            [[[[0.0]], [[1.0]], [[2.0]]], [[[3.0]], [[4.0]], [[5.0]]]]
        )
        model.conv_projection.weight = nn.Parameter(weight)

    @pytest.fixture
    def embedding(self):
        patch_embed = PatchEmbeddings(
            image_size=2,
            patch_size=1,
            hidden_size=2,
            use_image_masking=True,
        )
        # Projection bias should initialize to zero so only the weights
        # contribute to the outputs checked below.
        assert patch_embed.conv_projection.bias.sum().item() == 0
        self._init_conv_proj(patch_embed)
        patch_embed.eval()
        return patch_embed

    @pytest.fixture
    def embedding_patches_dropped(self):
        # NOTE(review): deliberately left in train mode — the 3-token
        # expected output below suggests patch dropping is only active
        # during training; confirm against PatchEmbeddings.
        patch_embed = PatchEmbeddings(
            image_size=2,
            patch_size=1,
            hidden_size=2,
            use_image_masking=False,
            patch_drop_rate=0.5,
        )
        self._init_conv_proj(patch_embed)
        return patch_embed

    def test_forward(self, inputs, embedding):
        actual = embedding(inputs).embeddings
        # First token (CLS) is all zeros; each of the 4 patches of the
        # all-ones input projects to [3, 12] under the fixed weights.
        expected = torch.Tensor([[[0.0, 0.0]] + [[3.0, 12.0]] * 4] * 2)
        assert_expected(actual, expected, atol=1e-4, rtol=0)

    def test_forward_masked(self, inputs, mask, embedding):
        actual = embedding(inputs, image_patches_mask=mask).embeddings
        cls_tok = [0.0, 0.0]
        masked = [0.0, 0.0]
        patch = [3.0, 12.0]
        expected = torch.Tensor(
            [
                [cls_tok, masked, masked, patch, masked],
                [cls_tok, patch, masked, masked, patch],
            ]
        )
        assert_expected(actual, expected, atol=1e-4, rtol=0)

    def test_forward_patches_dropped(self, inputs, embedding_patches_dropped):
        actual = embedding_patches_dropped(inputs).embeddings
        # With drop rate 0.5, half of the 4 patches are dropped,
        # leaving the CLS token plus 2 patch tokens per sample.
        expected = torch.Tensor([[[0.0, 0.0]] + [[3.0, 12.0]] * 2] * 2)
        assert_expected(actual, expected, atol=1e-4, rtol=0)

    def test_forward_rectangle_input(self):
        patch_embed = PatchEmbeddings(
            image_size=(4, 6),
            patch_size=2,
            hidden_size=2,
            use_image_masking=False,
            num_channels=1,
        )
        # 2x2 conv weights: out ch 0 is all zeros, out ch 1 sums the
        # patch scaled by 3 (-> 12 for an all-ones 2x2 patch).
        patch_embed.conv_projection.weight = nn.Parameter(
            torch.tensor([[[[0.0, 0.0], [0.0, 0.0]]], [[[3.0, 3.0], [3.0, 3.0]]]])
        )
        patch_embed.eval()
        actual = patch_embed(torch.ones(1, 1, 4, 6)).embeddings
        # (4/2) x (6/2) = 6 patches, each -> [0, 12], preceded by zero CLS.
        expected = torch.Tensor([[[0.0, 0.0]] + [[0.0, 12.0]] * 6])
        assert_expected(actual, expected, atol=1e-4, rtol=0)

    def test_forward_no_cls(self, inputs, mask):
        patch_embed = PatchEmbeddings(
            image_size=2,
            patch_size=1,
            hidden_size=2,
            use_image_masking=True,
            include_cls_embed=False,
        )
        self._init_conv_proj(patch_embed)
        actual = patch_embed(inputs).embeddings
        # Without the CLS token only the 4 patch embeddings remain.
        expected = torch.Tensor([[[3.0, 12.0]] * 4] * 2)
        assert_expected(actual, expected, atol=1e-4, rtol=0)