| """Tests for VAE2DEncoder and MelSpectrogram3Band.""" |
| import math |
| import os |
| import torch |
| import pytest |
| from arbitor.encoders.vae2d import load_vae2d |
| from arbitor.encoders.mel_frontend import MelSpectrogram3Band |
|
|
# The VAE2D tests in this module load the full PixArt/OpenSora VAE, so the
# whole module is opt-in: set ARB_RUN_SLOW_TESTS=1 to run it. A module-level
# `pytestmark` applies to every test here, including the mel-frontend tests.
_SLOW_TESTS_ENABLED = os.environ.get("ARB_RUN_SLOW_TESTS") == "1"

pytestmark = pytest.mark.skipif(
    not _SLOW_TESTS_ENABLED,
    reason="VAE2D sidecar tests load the full PixArt/OpenSora VAE path",
)
|
|
|
|
def test_vae2d_encoder_output_shape():
    """A 1x3x256x256 image encodes to a 4-channel latent at 1/8 the spatial size."""
    vae = load_vae2d("cpu")
    batch, height, width = 1, 256, 256
    encoded = vae(torch.randn(batch, 3, height, width))
    assert encoded.shape == (batch, 4, height // 8, width // 8)
|
|
|
|
def test_vae2d_encoder_requires_divisible_by_8():
    """The encoder accepts a non-256 resolution whose sides are divisible by 8.

    Despite the name, only the accepted case is exercised: a 224x224 input
    (224 == 8 * 28) must yield a latent downsampled by 8. Behaviour for a side
    that is *not* divisible by 8 is not checked here.
    """
    vae = load_vae2d("cpu")
    side = 224
    encoded = vae(torch.randn(1, 3, side, side))
    assert encoded.shape == (1, 4, side // 8, side // 8)
|
|
|
|
def test_mel_3band_output_shape():
    """Five seconds of 16 kHz audio yields a (1, 3, 64, frames) spectrogram."""
    sample_rate = 16000
    frontend = MelSpectrogram3Band(sample_rate=sample_rate)
    num_samples = 5 * sample_rate  # 80000 samples
    spectrogram = frontend(torch.randn(1, num_samples))
    # NOTE(review): 512 is presumably the frontend's hop length -- confirm.
    # If the frontend uses a centred STFT (1 + n // hop frames), ceil(n / hop)
    # agrees with it only when n is not a multiple of hop, which holds here.
    expected_frames = math.ceil(num_samples / 512)
    assert spectrogram.shape == (1, 3, 64, expected_frames)
|
|
|
|
def test_mel_3band_channels_distinct():
    """No two of the three spectrogram bands may be (numerically) identical.

    The input audio is drawn from a private, seeded generator so that a failure
    reproduces exactly without disturbing the global RNG used by other tests,
    and every pair of bands is compared (0-1, 1-2 and 0-2).
    """
    rng = torch.Generator().manual_seed(0)
    audio = torch.randn(1, 16000, generator=rng)
    spec = MelSpectrogram3Band()(audio)
    bands = spec[0]
    for i, j in ((0, 1), (1, 2), (0, 2)):
        assert not torch.allclose(bands[i], bands[j]), f"bands {i} and {j} match"
|
|
|
|
def test_vae2d_frozen():
    """The loaded encoder is frozen: none of its parameters requires gradients."""
    vae = load_vae2d("cpu")
    trainable = [param for param in vae.parameters() if param.requires_grad]
    assert not trainable
|
|
|
|
def test_vae2d_no_decoder():
    """The loaded encoder has fewer than 60M parameters in total.

    NOTE(review): the bound is a size heuristic -- presumably a model that still
    carried its decoder would exceed it. Confirm against the checkpoint.
    """
    max_params = 60_000_000
    n_params = 0
    for param in load_vae2d("cpu").parameters():
        n_params += param.numel()
    assert n_params < max_params
|
|
|
|
def test_vae2d_batch_independence():
    """Two different images in one batch encode to two different latents.

    The inputs come from a private, seeded generator, so a failure reproduces
    exactly and the global RNG is left untouched. The test checks that the batch
    dimension is preserved and that the samples are not collapsed to one latent.
    It does not compare against encoding each image on its own.
    """
    encoder = load_vae2d("cpu")
    rng = torch.Generator().manual_seed(0)
    imgs = torch.randn(2, 3, 256, 256, generator=rng)
    latent = encoder(imgs)
    assert latent.shape[0] == 2
    assert not torch.allclose(latent[0], latent[1])
|
|
|
|
def test_vae2d_on_mel_spectrogram():
    """The image VAE encodes a 3-band mel spectrogram as if it were an image.

    The audio length (48641 samples) is presumably chosen so the spectrogram's
    time axis is a multiple of 8; the first assert guards that precondition.
    """
    vae = load_vae2d("cpu")
    frontend = MelSpectrogram3Band(sample_rate=16000)
    spectrogram = frontend(torch.randn(1, 48641))
    frames = spectrogram.shape[-1]
    assert frames % 8 == 0
    encoded = vae(spectrogram)
    # dim 1: latent channels; dim 2: presumably 64 spectrogram rows / 8;
    # dim 3: spectrogram frames downsampled by 8.
    expected_sizes = {1: 4, 2: 8, 3: frames // 8}
    for dim, size in expected_sizes.items():
        assert encoded.shape[dim] == size
|
|