# ARBS / tests /test_vae2d_sequencer.py
# CLIWorks's picture
# Upload folder using huggingface_hub
# d8bc908 verified
"""Tests for VAE2DSequencer and VAEAudioSequencer."""
import os
import torch
import pytest
from arbitor.sequencers import VAE2DSequencer, VAEAudioSequencer
# Module-wide marker: every test in this file is skipped unless the
# ARB_RUN_SLOW_TESTS environment variable is set to exactly "1".
pytestmark = pytest.mark.skipif(
    os.getenv("ARB_RUN_SLOW_TESTS") != "1",
    reason="VAE sequencer tests load full sidecar encoders and 7168d projections",
)
def test_vae2d_sequencer_output_shape():
    """A (2, 3, 256, 256) random input must produce shape (2, 1024, 7168)."""
    sequencer = VAE2DSequencer()
    images = torch.randn((2, 3, 256, 256))
    tokens = sequencer(images)
    expected_shape = (2, 1024, 7168)
    assert tokens.shape == expected_shape
def test_vae2d_sequencer_224():
    """A single 3-channel 224x224 random input must produce shape (1, 784, 7168)."""
    sequencer = VAE2DSequencer()
    image = torch.randn((1, 3, 224, 224))
    tokens = sequencer(image)
    expected_shape = (1, 784, 7168)
    assert tokens.shape == expected_shape
def test_vae2d_sequencer_different_resolutions():
    """Check output dimensions across several input resolutions.

    One shared ``VAE2DSequencer`` receives a single 3-channel random image
    for each ``(h, w)`` pair. For every pair the last output dimension must
    be 7168 and dimension 1 must equal ``(h // 8) * (w // 8)``.
    """
    seq = VAE2DSequencer()
    for h, w in [(128, 128), (256, 192), (512, 512)]:
        img = torch.randn(1, 3, h, w)
        out = seq(img)
        # The asserts run inside a loop, so each message names the
        # resolution; otherwise a failure would not say which pair broke.
        assert out.shape[-1] == 7168, (
            f"unexpected feature dim for {h}x{w} input: {tuple(out.shape)}"
        )
        assert out.shape[1] == (h // 8) * (w // 8), (
            f"unexpected token count for {h}x{w} input: {tuple(out.shape)}"
        )
def test_vae2d_sequencer_no_vit_params():
    """The sequencer must have fewer than 100,000 trainable parameter elements.

    Only parameters whose ``requires_grad`` flag is set are counted.
    NOTE(review): judging by the test name, this presumably guards against
    a ViT backbone contributing trainable weights -- confirm.
    """
    sequencer = VAE2DSequencer()
    trainable_total = 0
    for param in sequencer.parameters():
        if param.requires_grad:
            trainable_total += param.numel()
    assert trainable_total < 100_000
def test_vae2d_sequencer_output_range():
    """Output for a (1, 3, 256, 256) random input must be finite and bounded.

    Every element has to be finite, and the mean absolute value of the
    output has to stay below 100.0.
    """
    sequencer = VAE2DSequencer()
    image = torch.randn((1, 3, 256, 256))
    tokens = sequencer(image)
    finite_mask = torch.isfinite(tokens)
    assert finite_mask.all()
    mean_magnitude = tokens.abs().mean()
    assert mean_magnitude < 100.0
def test_vae2d_sequencer_batch():
    """A batch of four 256x256 inputs must keep 4 as the leading output dimension."""
    sequencer = VAE2DSequencer()
    batch = torch.randn((4, 3, 256, 256))
    tokens = sequencer(batch)
    leading_dim = tokens.shape[0]
    assert leading_dim == 4
def test_vae_audio_sequencer_output_shape():
    """A (1, 48000) random input must give last dim 7168 and leading dim 1."""
    sequencer = VAEAudioSequencer()
    waveform = torch.randn((1, 48000))
    tokens = sequencer(waveform)
    result_shape = tokens.shape
    assert result_shape[-1] == 7168
    assert result_shape[0] == 1
def test_vae_audio_sequencer_mono_tensor():
    """A 3-D (1, 1, 16000) random input must give an output whose last dim is 7168."""
    sequencer = VAEAudioSequencer()
    waveform = torch.randn((1, 1, 16000))
    tokens = sequencer(waveform)
    feature_dim = tokens.shape[-1]
    assert feature_dim == 7168
def test_vae_audio_sequencer_batch():
    """A (2, 16000) random input must keep 2 as the leading output dimension."""
    sequencer = VAEAudioSequencer()
    waveforms = torch.randn((2, 16000))
    tokens = sequencer(waveforms)
    leading_dim = tokens.shape[0]
    assert leading_dim == 2
def test_vae_audio_no_moonshine_params():
    """The audio sequencer must have fewer than 100,000 trainable parameter elements.

    Only parameters whose ``requires_grad`` flag is set are counted.
    NOTE(review): judging by the test name, this presumably guards against
    a Moonshine encoder contributing trainable weights -- confirm.
    """
    sequencer = VAEAudioSequencer()
    trainable_total = 0
    for param in sequencer.parameters():
        if param.requires_grad:
            trainable_total += param.numel()
    assert trainable_total < 100_000
def test_vae_audio_output_range():
    """Every output element for a (1, 16000) random input must be finite."""
    sequencer = VAEAudioSequencer()
    waveform = torch.randn((1, 16000))
    tokens = sequencer(waveform)
    finite_mask = torch.isfinite(tokens)
    assert finite_mask.all()
def test_vae_audio_variable_length():
    """A longer input must yield more entries along output dimension 1.

    The same sequencer processes an 8000-sample and a 48000-sample random
    input; dimension 1 of the first output must be strictly smaller than
    that of the second.
    """
    sequencer = VAEAudioSequencer()
    # Inputs are drawn in the same order as before (short first, then long)
    # so the random stream is consumed identically.
    brief_clip = torch.randn((1, 8000))
    extended_clip = torch.randn((1, 48000))
    brief_tokens = sequencer(brief_clip)
    extended_tokens = sequencer(extended_clip)
    brief_len = brief_tokens.shape[1]
    extended_len = extended_tokens.shape[1]
    assert brief_len < extended_len