| """Tests for VAE2DSequencer and VAEAudioSequencer.""" |
| import os |
| import torch |
| import pytest |
| from arbitor.sequencers import VAE2DSequencer, VAEAudioSequencer |
|
|
# Opt-in gate: every test in this module builds the real sequencers (full
# sidecar encoders plus 7168-d projections, per the skip reason), so the
# module only runs when ARB_RUN_SLOW_TESTS is set to exactly "1".
_slow_tests_enabled = os.environ.get("ARB_RUN_SLOW_TESTS") == "1"
pytestmark = pytest.mark.skipif(
    not _slow_tests_enabled,
    reason="VAE sequencer tests load full sidecar encoders and 7168d projections",
)
|
|
|
|
def test_vae2d_sequencer_output_shape():
    """Two 256x256 RGB images map to 1024 tokens of width 7168 each."""
    sequencer = VAE2DSequencer()
    tokens = sequencer(torch.randn(2, 3, 256, 256))
    assert tuple(tokens.shape) == (2, 1024, 7168)
|
|
|
|
def test_vae2d_sequencer_224():
    """A single 224x224 image yields 784 (= 28 * 28) tokens of width 7168."""
    sequencer = VAE2DSequencer()
    image = torch.randn(1, 3, 224, 224)
    expected = (1, 784, 7168)
    assert sequencer(image).shape == expected
|
|
|
|
def test_vae2d_sequencer_different_resolutions():
    """Token count follows the input size: (h // 8) * (w // 8) tokens.

    Covers square and non-square inputs. The full (batch, tokens, 7168)
    shape is asserted rather than individual axes, so a dropped batch
    dimension or an unexpected rank is caught too. The failure message
    names the offending resolution, since the checks run in one loop.
    """
    seq = VAE2DSequencer()
    for h, w in [(128, 128), (256, 192), (512, 512)]:
        img = torch.randn(1, 3, h, w)
        out = seq(img)
        expected = (1, (h // 8) * (w // 8), 7168)
        assert tuple(out.shape) == expected, (
            f"{h}x{w}: got {tuple(out.shape)}, expected {expected}"
        )
|
|
|
|
def test_vae2d_sequencer_no_vit_params():
    """The sequencer exposes fewer than 100k trainable parameters.

    NOTE(review): presumably this guards that the heavy image encoder is
    frozen and only light projection weights train — confirm against
    VAE2DSequencer's constructor.
    """
    sequencer = VAE2DSequencer()
    trainable = [p for p in sequencer.parameters() if p.requires_grad]
    assert sum(p.numel() for p in trainable) < 100_000
|
|
|
|
def test_vae2d_sequencer_output_range():
    """Outputs for a random 256x256 image are finite with mean |value| < 100."""
    sequencer = VAE2DSequencer()
    tokens = sequencer(torch.randn(1, 3, 256, 256))
    assert torch.isfinite(tokens).all()
    mean_magnitude = tokens.abs().mean()
    assert mean_magnitude < 100.0
|
|
|
|
def test_vae2d_sequencer_batch():
    """A batch of four 256x256 images keeps its batch axis and per-image shape.

    Asserting the full shape, not only the batch axis, catches a sequencer
    that preserves the batch dimension but mangles the token or feature axes.
    """
    seq = VAE2DSequencer()
    imgs = torch.randn(4, 3, 256, 256)
    out = seq(imgs)
    # 256x256 -> 1024 tokens of width 7168, as in the two-image shape test.
    assert tuple(out.shape) == (4, 1024, 7168)
|
|
|
|
def test_vae_audio_sequencer_output_shape():
    """A (1, 48000) waveform maps to a rank-3 (batch, tokens, 7168) sequence.

    The rank check matches how test_vae_audio_variable_length reads
    ``shape[1]`` as the token axis. Without it, a flattened (1, 7168)
    output would satisfy the batch and feature checks below.
    """
    seq = VAEAudioSequencer()
    audio = torch.randn(1, 48000)
    out = seq(audio)
    assert out.ndim == 3, f"expected (batch, tokens, dim), got {tuple(out.shape)}"
    assert out.shape[-1] == 7168
    assert out.shape[0] == 1
|
|
|
|
def test_vae_audio_sequencer_mono_tensor():
    """A 3-D (1, 1, 16000) input is accepted and yields 7168-d features.

    NOTE(review): the extra axis presumably means (batch, channels, samples)
    for mono audio — confirm against VAEAudioSequencer.forward.
    """
    sequencer = VAEAudioSequencer()
    mono = torch.randn(1, 1, 16000)
    feature_dim = sequencer(mono).shape[-1]
    assert feature_dim == 7168
|
|
|
|
def test_vae_audio_sequencer_batch():
    """Two waveforms in one batch keep the batch axis and 7168-d features."""
    seq = VAEAudioSequencer()
    audios = torch.randn(2, 16000)
    out = seq(audios)
    assert out.shape[0] == 2
    # Same feature width the single-item audio tests assert; checking it here
    # catches a batched path that keeps the batch axis but breaks projection.
    assert out.shape[-1] == 7168
|
|
|
|
def test_vae_audio_no_moonshine_params():
    """The audio sequencer exposes fewer than 100k trainable parameters.

    NOTE(review): presumably this guards that the Moonshine encoder is
    frozen and only light projection weights train — confirm against
    VAEAudioSequencer's constructor.
    """
    sequencer = VAEAudioSequencer()
    n_trainable = 0
    for param in sequencer.parameters():
        if param.requires_grad:
            n_trainable += param.numel()
    assert n_trainable < 100_000
|
|
|
|
def test_vae_audio_output_range():
    """Encoding a random (1, 16000) waveform produces only finite values."""
    sequencer = VAEAudioSequencer()
    tokens = sequencer(torch.randn(1, 16000))
    all_finite = torch.isfinite(tokens).all()
    assert all_finite
|
|
|
|
def test_vae_audio_variable_length():
    """A longer waveform yields strictly more tokens along axis 1."""
    sequencer = VAEAudioSequencer()
    short_clip = torch.randn(1, 8000)
    long_clip = torch.randn(1, 48000)
    # Encode short first, then long (same call order as a straight sequence).
    short_len, long_len = (
        sequencer(clip).shape[1] for clip in (short_clip, long_clip)
    )
    assert short_len < long_len
|
|