| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | from functools import partial |
| |
|
| | import pytest |
| | import torch |
| | from tests.test_utils import assert_expected, set_rng_seed |
| |
|
| | from torchmultimodal.diffusion_labs.models.vae.encoder_decoder import ( |
| | res_block, |
| | res_block_stack, |
| | ResNetDecoder, |
| | ResNetEncoder, |
| | ) |
| |
|
| |
|
| | @pytest.fixture(autouse=True) |
| | def set_seed(): |
| | set_rng_seed(54321) |
| |
|
| |
|
| | @pytest.fixture |
| | def in_channels(): |
| | return 2 |
| |
|
| |
|
| | @pytest.fixture |
| | def out_channels(): |
| | return 5 |
| |
|
| |
|
| | @pytest.fixture |
| | def z_channels(): |
| | return 3 |
| |
|
| |
|
| | @pytest.fixture |
| | def channels(): |
| | return 4 |
| |
|
| |
|
| | @pytest.fixture |
| | def num_res_blocks(): |
| | return 2 |
| |
|
| |
|
| | @pytest.fixture |
| | def channel_multipliers(): |
| | return (1, 2) |
| |
|
| |
|
| | @pytest.fixture |
| | def norm_groups(): |
| | return 2 |
| |
|
| |
|
| | @pytest.fixture |
| | def norm_eps(): |
| | return 1e-05 |
| |
|
| |
|
| | @pytest.fixture |
| | def x(in_channels): |
| | bsize = 2 |
| | height = 16 |
| | width = 16 |
| | return torch.randn(bsize, in_channels, height, width) |
| |
|
| |
|
| | @pytest.fixture |
| | def z(z_channels): |
| | bsize = 2 |
| | height = 4 |
| | width = 4 |
| | return torch.randn(bsize, z_channels, height, width) |
| |
|
| |
|
| | class TestResNetEncoder: |
| | @pytest.fixture |
| | def encoder( |
| | self, |
| | in_channels, |
| | z_channels, |
| | channels, |
| | num_res_blocks, |
| | channel_multipliers, |
| | norm_groups, |
| | norm_eps, |
| | ): |
| | return partial( |
| | ResNetEncoder, |
| | in_channels=in_channels, |
| | z_channels=z_channels, |
| | channels=channels, |
| | num_res_blocks=num_res_blocks, |
| | channel_multipliers=channel_multipliers, |
| | norm_groups=norm_groups, |
| | norm_eps=norm_eps, |
| | ) |
| |
|
| | @pytest.mark.parametrize("double_z", [True, False]) |
| | def test_forward_dims(self, encoder, x, z_channels, channel_multipliers, double_z): |
| | encoder_module = encoder(double_z=double_z) |
| | output = encoder_module(x) |
| | assert_expected( |
| | output.size(), |
| | torch.Size( |
| | [ |
| | x.size(0), |
| | z_channels * (2 if double_z else 1), |
| | x.size(2) // 2 ** (len(channel_multipliers) - 1), |
| | x.size(3) // 2 ** (len(channel_multipliers) - 1), |
| | ] |
| | ), |
| | ) |
| |
|
| | def test_forward(self, encoder, x): |
| | encoder_module = encoder() |
| | actual = encoder_module(x) |
| | expected = torch.tensor(126.5277) |
| | assert_expected(actual.sum(), expected, rtol=0, atol=1e-4) |
| |
|
| | def test_channel_indivisble_norm_group_error(self, encoder): |
| | with pytest.raises(ValueError): |
| | _ = encoder(norm_groups=7) |
| |
|
| |
|
| | class TestResNetDecoder: |
| | @pytest.fixture |
| | def decoder( |
| | self, |
| | out_channels, |
| | z_channels, |
| | channels, |
| | num_res_blocks, |
| | channel_multipliers, |
| | norm_groups, |
| | norm_eps, |
| | ): |
| | return partial( |
| | ResNetDecoder, |
| | out_channels=out_channels, |
| | z_channels=z_channels, |
| | channels=channels, |
| | num_res_blocks=num_res_blocks, |
| | channel_multipliers=channel_multipliers, |
| | norm_groups=norm_groups, |
| | norm_eps=norm_eps, |
| | ) |
| |
|
| | @pytest.mark.parametrize("output_alpha_channel", [True, False]) |
| | def test_forward_dims( |
| | self, decoder, z, out_channels, channel_multipliers, output_alpha_channel |
| | ): |
| | decoder_module = decoder(output_alpha_channel=output_alpha_channel) |
| | output = decoder_module(z) |
| | assert_expected( |
| | output.size(), |
| | torch.Size( |
| | [ |
| | z.size(0), |
| | out_channels + (1 if output_alpha_channel else 0), |
| | z.size(2) * 2 ** (len(channel_multipliers) - 1), |
| | z.size(3) * 2 ** (len(channel_multipliers) - 1), |
| | ] |
| | ), |
| | ) |
| |
|
| | def test_forward(self, decoder, z): |
| | decoder_module = decoder() |
| | actual = decoder_module(z) |
| | expected = torch.tensor(-10.0260) |
| | assert_expected(actual.sum(), expected, rtol=0, atol=1e-4) |
| |
|
| | def test_forward_alpha_channel(self, decoder, z): |
| | decoder_module = decoder(output_alpha_channel=True) |
| | actual = decoder_module(z) |
| | expected = torch.tensor(-16.2157) |
| | assert_expected(actual.sum(), expected, rtol=0, atol=1e-4) |
| |
|
| | def test_channel_indivisble_norm_group_error(self, decoder): |
| | with pytest.raises(ValueError): |
| | _ = decoder(norm_groups=7) |
| |
|
| |
|
| | @pytest.mark.parametrize("out_channels,expected_value", [(2, 52.2716), (4, 152.8285)]) |
| | def test_res_block(x, out_channels, expected_value): |
| | in_channels = x.size(1) |
| | res_block_module = res_block(in_channels, out_channels, dropout=0.3, norm_groups=1) |
| | actual = res_block_module(x) |
| | expected = torch.tensor(expected_value) |
| | assert_expected( |
| | actual.size(), |
| | torch.Size( |
| | [ |
| | x.size(0), |
| | out_channels, |
| | x.size(2), |
| | x.size(3), |
| | ] |
| | ), |
| | ) |
| | assert_expected(actual.sum(), expected, rtol=0, atol=1e-4) |
| |
|
| |
|
| | @pytest.mark.parametrize( |
| | "needs_upsample,needs_downsample,expected_value", |
| | [(False, True, 28.02428), (False, False, 382.8569), (True, False, 581.62414)], |
| | ) |
| | def test_res_block_stack( |
| | x, |
| | in_channels, |
| | channels, |
| | num_res_blocks, |
| | needs_upsample, |
| | needs_downsample, |
| | expected_value, |
| | ): |
| | res_block_stack_module = res_block_stack( |
| | in_channels=in_channels, |
| | out_channels=channels, |
| | num_blocks=num_res_blocks, |
| | dropout=0.1, |
| | needs_upsample=needs_upsample, |
| | needs_downsample=needs_downsample, |
| | norm_groups=1, |
| | ) |
| | actual = res_block_stack_module(x) |
| | expected = torch.tensor(expected_value) |
| | if needs_upsample: |
| | size_multipler = 2 |
| | elif needs_downsample: |
| | size_multipler = 0.5 |
| | else: |
| | size_multipler = 1 |
| | assert_expected( |
| | actual.size(), |
| | torch.Size( |
| | [ |
| | x.size(0), |
| | channels, |
| | int(x.size(2) * size_multipler), |
| | int(x.size(3) * size_multipler), |
| | ] |
| | ), |
| | ) |
| | assert_expected(actual.sum(), expected, rtol=0, atol=1e-4) |
| |
|
| |
|
| | def test_res_block_stack_exception( |
| | in_channels, |
| | channels, |
| | num_res_blocks, |
| | ): |
| | with pytest.raises(ValueError): |
| | _ = res_block_stack( |
| | in_channels=in_channels, |
| | out_channels=channels, |
| | num_blocks=num_res_blocks, |
| | needs_upsample=True, |
| | needs_downsample=True, |
| | ) |
| |
|