# Source: ARBS / tests/test_vae2d.py (Hugging Face repository file)
# Commit d8bc908 (verified) by CLIWorks: "Upload folder using huggingface_hub"
"""Tests for VAE2DEncoder and MelSpectrogram3Band."""
import math
import os
import torch
import pytest
from arbitor.encoders.vae2d import load_vae2d
from arbitor.encoders.mel_frontend import MelSpectrogram3Band
# Opt-in gate for the whole module: every test here is skipped unless the
# environment variable ARB_RUN_SLOW_TESTS is set to exactly "1".
# NOTE(review): because this mark is module-wide, it also skips the mel-only
# tests, which never call load_vae2d; consider moving the mark onto the VAE
# tests alone so the cheap frontend checks run by default.
_SLOW_TESTS_ENABLED = os.environ.get("ARB_RUN_SLOW_TESTS") == "1"
pytestmark = pytest.mark.skipif(
    not _SLOW_TESTS_ENABLED,
    reason="VAE2D sidecar tests load the full PixArt/OpenSora VAE path",
)
def test_vae2d_encoder_output_shape():
    """A 1x3x256x256 input encodes to a 4-channel latent at 1/8 spatial size."""
    vae = load_vae2d("cpu")
    batch, height, width = 1, 256, 256
    out = vae(torch.randn(batch, 3, height, width))
    assert out.shape == (batch, 4, height // 8, width // 8)
def test_vae2d_encoder_requires_divisible_by_8():
    """A non-256 input whose sides are divisible by 8 still encodes at 1/8 size.

    NOTE(review): despite its name, this test only covers a side length that
    *is* divisible by 8 (224 -> 28); it does not check what the encoder does
    with a non-divisible size. TODO: add that case once the expected behavior
    (error vs. padding/cropping) is pinned down.
    """
    side = 224
    vae = load_vae2d("cpu")
    out = vae(torch.randn(1, 3, side, side))
    assert out.shape == (1, 4, side // 8, side // 8)
def test_mel_3band_output_shape():
    """Five seconds of 16 kHz audio yields a (1, 3, 64, frames) spectrogram."""
    num_samples = 80000
    frontend = MelSpectrogram3Band(sample_rate=16000)
    spectrogram = frontend(torch.randn(1, num_samples))
    # Expected frame count assumes a hop length of 512 samples — TODO confirm
    # against MelSpectrogram3Band's configuration.
    expected_frames = math.ceil(num_samples / 512)
    assert spectrogram.shape == (1, 3, 64, expected_frames)
def test_mel_3band_channels_distinct():
    """Adjacent channels of the 3-band spectrogram must not be numerically equal."""
    waveform = torch.randn(1, 16000)
    channels = MelSpectrogram3Band()(waveform)[0]
    for left, right in ((0, 1), (1, 2)):
        assert not torch.allclose(channels[left], channels[right])
def test_vae2d_frozen():
    """The loaded encoder has parameters, and none of them requires gradients.

    An encoder with no parameters is rejected explicitly; otherwise the
    "all frozen" check would pass vacuously. On failure, the message reports
    how many parameter tensors are still trainable, instead of a bare assert
    on the first offender.
    """
    encoder = load_vae2d("cpu")
    params = list(encoder.parameters())
    assert params, "load_vae2d returned an encoder with no parameters"
    n_trainable = sum(1 for p in params if p.requires_grad)
    assert n_trainable == 0, (
        f"{n_trainable} of {len(params)} parameter tensors still require grad"
    )
def test_vae2d_no_decoder():
    """The sidecar keeps only the encoder half of the VAE.

    Total parameter count is used as a proxy. It must be positive, so an empty
    module cannot pass vacuously, and below 60M.
    """
    encoder = load_vae2d("cpu")
    total = sum(p.numel() for p in encoder.parameters())
    # NOTE(review): the 60M ceiling is presumably chosen to sit below the size
    # of a full encoder+decoder checkpoint; confirm against the actual weights.
    assert 0 < total < 60_000_000, f"unexpected parameter count: {total:,}"
def test_vae2d_batch_independence():
    """Encoding a batch of two random images yields two differing latents.

    NOTE(review): despite the name, this only checks that the two outputs are
    not (approximately) equal. It does not compare against encoding each image
    on its own, which would need the encoder to be deterministic. Confirm
    whether the encoder samples from its posterior before strengthening it.
    """
    vae = load_vae2d("cpu")
    pair = torch.randn(2, 3, 256, 256)
    latents = vae(pair)
    assert latents.shape[0] == 2
    first, second = latents.unbind(0)
    assert not torch.allclose(first, second)
def test_vae2d_on_mel_spectrogram():
    """The 2D VAE encodes a 3-band mel spectrogram as if it were an image.

    The full latent shape is pinned, including the batch dim. Its spatial dims
    are derived from the spectrogram, whose sides are each downsampled by 8,
    rather than from a hard-coded 8.
    """
    encoder = load_vae2d("cpu")
    mel = MelSpectrogram3Band(sample_rate=16000)
    # Length chosen so the frame count comes out divisible by 8, the encoder's
    # spatial downsampling factor.
    length = 48641
    audio = torch.randn(1, length)
    spec = mel(audio)
    n_mels, n_frames = spec.shape[-2], spec.shape[-1]
    assert n_mels == 64
    assert n_frames % 8 == 0
    latent = encoder(spec)
    assert latent.shape == (1, 4, n_mels // 8, n_frames // 8)