# vlm_clone_2/multimodal/tests/diffusion_labs/test_vae_encoder_decoder.py
# tuandunghcmut's picture
# Add files using upload-large-folder tool
# f233443 verified
#!/usr/bin/env fbpython
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from functools import partial
import pytest
import torch
from tests.test_utils import assert_expected, set_rng_seed
from torchmultimodal.diffusion_labs.models.vae.encoder_decoder import (
res_block,
res_block_stack,
ResNetDecoder,
ResNetEncoder,
)
@pytest.fixture(autouse=True)
def set_seed() -> None:
    """Seed the RNG (via ``set_rng_seed``) before every test in this module.

    The fixture is ``autouse`` so each test starts from the same seed; the
    tests below compare output sums against hard-coded constants.
    """
    seed = 54321
    set_rng_seed(seed)
@pytest.fixture
def in_channels() -> int:
    """Channel count of the sample input ``x``.

    Also passed as ``in_channels`` to the encoder and to ``res_block_stack``.
    """
    return 2
@pytest.fixture
def out_channels() -> int:
    """Value passed as ``out_channels`` when building the decoder."""
    return 5
@pytest.fixture
def z_channels() -> int:
    """Channel count of the sample latent ``z``.

    Also passed as ``z_channels`` to both the encoder and the decoder.
    """
    return 3
@pytest.fixture
def channels() -> int:
    """Value passed as ``channels`` to the encoder and decoder.

    ``test_res_block_stack`` also uses it as the stack's ``out_channels``.
    """
    return 4
@pytest.fixture
def num_res_blocks() -> int:
    """Value passed as ``num_res_blocks`` to the encoder and decoder.

    The ``res_block_stack`` tests pass it as ``num_blocks``.
    """
    return 2
@pytest.fixture
def channel_multipliers():
    """Tuple passed as ``channel_multipliers`` to the encoder and decoder.

    The shape tests derive their expected spatial scale factor from its
    length: ``2 ** (len(channel_multipliers) - 1)``.
    """
    multipliers = (1, 2)
    return multipliers
@pytest.fixture
def norm_groups() -> int:
    """Value passed as ``norm_groups`` to the encoder and decoder."""
    return 2
@pytest.fixture
def norm_eps() -> float:
    """Value passed as ``norm_eps`` to the encoder and decoder."""
    return 1e-05
@pytest.fixture
def x(in_channels):
    """Random sample input of shape ``(2, in_channels, 16, 16)``."""
    batch_size, side = 2, 16
    shape = (batch_size, in_channels, side, side)
    return torch.randn(*shape)
@pytest.fixture
def z(z_channels):
    """Random sample latent of shape ``(2, z_channels, 4, 4)``."""
    batch_size, side = 2, 4
    shape = (batch_size, z_channels, side, side)
    return torch.randn(*shape)
class TestResNetEncoder:
    """Tests for ``ResNetEncoder``: output shape, a pinned output sum, and
    rejection of an incompatible ``norm_groups`` value."""

    @pytest.fixture
    def encoder(
        self,
        in_channels,
        z_channels,
        channels,
        num_res_blocks,
        channel_multipliers,
        norm_groups,
        norm_eps,
    ):
        """Return a ``ResNetEncoder`` factory pre-bound to the shared fixtures.

        Tests call the returned ``partial`` with extra or overriding keyword
        arguments (e.g. ``double_z``, ``norm_groups``) to build the module.
        """
        bound_kwargs = {
            "in_channels": in_channels,
            "z_channels": z_channels,
            "channels": channels,
            "num_res_blocks": num_res_blocks,
            "channel_multipliers": channel_multipliers,
            "norm_groups": norm_groups,
            "norm_eps": norm_eps,
        }
        return partial(ResNetEncoder, **bound_kwargs)

    @pytest.mark.parametrize("double_z", [True, False])
    def test_forward_dims(self, encoder, x, z_channels, channel_multipliers, double_z):
        """Check the encoder output size for both settings of ``double_z``.

        Expected: same batch size as ``x``, ``z_channels`` channels (doubled
        when ``double_z`` is true), and spatial dims floor-divided by
        ``2 ** (len(channel_multipliers) - 1)``.
        """
        module = encoder(double_z=double_z)
        out = module(x)

        batch, _, height, width = x.size()
        downscale = 2 ** (len(channel_multipliers) - 1)
        if double_z:
            want_channels = z_channels * 2
        else:
            want_channels = z_channels * 1
        want_size = torch.Size(
            [batch, want_channels, height // downscale, width // downscale]
        )
        assert_expected(out.size(), want_size)

    def test_forward(self, encoder, x):
        """Pin the sum of the encoder output (built with no extra options)
        to a hard-coded value, with an absolute tolerance of 1e-4."""
        module = encoder()
        out = module(x)
        assert_expected(out.sum(), torch.tensor(126.5277), rtol=0, atol=1e-4)

    def test_channel_indivisble_norm_group_error(self, encoder):
        """Constructing the encoder with ``norm_groups=7`` must raise
        ``ValueError``."""
        with pytest.raises(ValueError):
            encoder(norm_groups=7)
class TestResNetDecoder:
    """Tests for ``ResNetDecoder``: output shape, pinned output sums (with and
    without the alpha channel), and rejection of an incompatible
    ``norm_groups`` value."""

    @pytest.fixture
    def decoder(
        self,
        out_channels,
        z_channels,
        channels,
        num_res_blocks,
        channel_multipliers,
        norm_groups,
        norm_eps,
    ):
        """Return a ``ResNetDecoder`` factory pre-bound to the shared fixtures.

        Tests call the returned ``partial`` with extra or overriding keyword
        arguments (e.g. ``output_alpha_channel``, ``norm_groups``) to build
        the module.
        """
        bound_kwargs = {
            "out_channels": out_channels,
            "z_channels": z_channels,
            "channels": channels,
            "num_res_blocks": num_res_blocks,
            "channel_multipliers": channel_multipliers,
            "norm_groups": norm_groups,
            "norm_eps": norm_eps,
        }
        return partial(ResNetDecoder, **bound_kwargs)

    @pytest.mark.parametrize("output_alpha_channel", [True, False])
    def test_forward_dims(
        self, decoder, z, out_channels, channel_multipliers, output_alpha_channel
    ):
        """Check the decoder output size for both ``output_alpha_channel`` values.

        Expected: same batch size as ``z``, ``out_channels`` channels (plus one
        when ``output_alpha_channel`` is true), and spatial dims multiplied by
        ``2 ** (len(channel_multipliers) - 1)``.
        """
        module = decoder(output_alpha_channel=output_alpha_channel)
        out = module(z)

        batch, _, height, width = z.size()
        upscale = 2 ** (len(channel_multipliers) - 1)
        if output_alpha_channel:
            want_channels = out_channels + 1
        else:
            want_channels = out_channels + 0
        want_size = torch.Size(
            [batch, want_channels, height * upscale, width * upscale]
        )
        assert_expected(out.size(), want_size)

    def test_forward(self, decoder, z):
        """Pin the sum of the decoder output (built with no extra options)
        to a hard-coded value, with an absolute tolerance of 1e-4."""
        module = decoder()
        out = module(z)
        assert_expected(out.sum(), torch.tensor(-10.0260), rtol=0, atol=1e-4)

    def test_forward_alpha_channel(self, decoder, z):
        """Pin the sum of the decoder output when built with
        ``output_alpha_channel=True``, with an absolute tolerance of 1e-4."""
        module = decoder(output_alpha_channel=True)
        out = module(z)
        assert_expected(out.sum(), torch.tensor(-16.2157), rtol=0, atol=1e-4)

    def test_channel_indivisble_norm_group_error(self, decoder):
        """Constructing the decoder with ``norm_groups=7`` must raise
        ``ValueError``."""
        with pytest.raises(ValueError):
            decoder(norm_groups=7)
@pytest.mark.parametrize("out_channels,expected_value", [(2, 52.2716), (4, 152.8285)])
def test_res_block(x, out_channels, expected_value):
    """Check the output size and pinned output sum of a single ``res_block``.

    ``out_channels`` comes from the parametrization here rather than from the
    module-level fixture. The block must keep the batch and spatial dims of
    ``x`` and produce ``out_channels`` channels.
    """
    batch, in_ch, height, width = x.size()
    block = res_block(in_ch, out_channels, dropout=0.3, norm_groups=1)
    out = block(x)

    want_size = torch.Size([batch, out_channels, height, width])
    assert_expected(out.size(), want_size)
    assert_expected(out.sum(), torch.tensor(expected_value), rtol=0, atol=1e-4)
@pytest.mark.parametrize(
    "needs_upsample,needs_downsample,expected_value",
    [(False, True, 28.02428), (False, False, 382.8569), (True, False, 581.62414)],
)
def test_res_block_stack(
    x,
    in_channels,
    channels,
    num_res_blocks,
    needs_upsample,
    needs_downsample,
    expected_value,
):
    """Check the output size and pinned output sum of ``res_block_stack``.

    Covers downsample-only, neither, and upsample-only. The expected spatial
    dims are those of ``x`` scaled by 0.5, 1, or 2 respectively, with
    ``channels`` output channels.
    """
    stack = res_block_stack(
        in_channels=in_channels,
        out_channels=channels,
        num_blocks=num_res_blocks,
        dropout=0.1,
        needs_upsample=needs_upsample,
        needs_downsample=needs_downsample,
        norm_groups=1,
    )
    out = stack(x)

    # Upsampling takes precedence in this selection, mirroring the original
    # if/elif ordering; both flags are never set together in the cases above.
    scale = 2 if needs_upsample else (0.5 if needs_downsample else 1)
    want_height = int(x.size(2) * scale)
    want_width = int(x.size(3) * scale)
    want_size = torch.Size([x.size(0), channels, want_height, want_width])

    assert_expected(out.size(), want_size)
    assert_expected(out.sum(), torch.tensor(expected_value), rtol=0, atol=1e-4)
def test_res_block_stack_exception(
    in_channels,
    channels,
    num_res_blocks,
):
    """Requesting both upsampling and downsampling from ``res_block_stack``
    must raise ``ValueError``."""
    conflicting_kwargs = {
        "in_channels": in_channels,
        "out_channels": channels,
        "num_blocks": num_res_blocks,
        "needs_upsample": True,
        "needs_downsample": True,
    }
    with pytest.raises(ValueError):
        res_block_stack(**conflicting_kwargs)