File size: 2,556 Bytes
d8bc908
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
"""Tests for VAE2DSequencer and VAEAudioSequencer."""
import os
import torch
import pytest
from arbitor.sequencers import VAE2DSequencer, VAEAudioSequencer

# Module-wide opt-in gate: every test here is skipped unless the environment
# variable ARB_RUN_SLOW_TESTS is set to exactly "1".
_slow_tests_enabled = os.getenv("ARB_RUN_SLOW_TESTS") == "1"

pytestmark = pytest.mark.skipif(
    not _slow_tests_enabled,
    reason="VAE sequencer tests load full sidecar encoders and 7168d projections",
)


def test_vae2d_sequencer_output_shape():
    """A batch of two 3x256x256 inputs must yield a (2, 1024, 7168) output."""
    sequencer = VAE2DSequencer()
    images = torch.randn(2, 3, 256, 256)
    tokens = sequencer(images)
    expected_shape = (2, 1024, 7168)
    assert tokens.shape == expected_shape


def test_vae2d_sequencer_224():
    """A single 3x224x224 input must yield a (1, 784, 7168) output."""
    sequencer = VAE2DSequencer()
    image = torch.randn(1, 3, 224, 224)
    tokens = sequencer(image)
    expected_shape = (1, 784, 7168)
    assert tokens.shape == expected_shape


def test_vae2d_sequencer_different_resolutions():
    """Check output dimensions across several input resolutions.

    For each (height, width) pair the last output dimension must be 7168 and
    dimension 1 must equal ``(height // 8) * (width // 8)``. A single
    sequencer instance is reused for all resolutions.
    """
    sequencer = VAE2DSequencer()
    resolutions = ((128, 128), (256, 192), (512, 512))
    for height, width in resolutions:
        tokens = sequencer(torch.randn(1, 3, height, width))
        expected_length = (height // 8) * (width // 8)
        assert tokens.shape[-1] == 7168
        assert tokens.shape[1] == expected_length


def test_vae2d_sequencer_no_vit_params():
    """The sequencer must expose fewer than 100,000 trainable parameters.

    Only parameters with ``requires_grad`` set are counted.
    NOTE(review): the test name suggests this guards against a ViT backbone
    being trainable -- the code itself only checks the count.
    """
    sequencer = VAE2DSequencer()
    trainable = [param for param in sequencer.parameters() if param.requires_grad]
    trainable_count = sum(param.numel() for param in trainable)
    assert trainable_count < 100_000


def test_vae2d_sequencer_output_range():
    """Output for a 3x256x256 input must be finite with mean |x| below 100.

    The global torch RNG is seeded first so that the random input (and any
    random initialisation done while constructing the sequencer) is the same
    on every run; a numeric-threshold assertion on unseeded random data is
    not reproducible when it fails.
    """
    # Seed before construction so both weight init and input are fixed.
    torch.manual_seed(0)
    seq = VAE2DSequencer()
    img = torch.randn(1, 3, 256, 256)
    out = seq(img)
    assert torch.isfinite(out).all()
    assert out.abs().mean() < 100.0


def test_vae2d_sequencer_batch():
    """The leading output dimension must match an input batch size of 4."""
    sequencer = VAE2DSequencer()
    batch_size = 4
    images = torch.randn(batch_size, 3, 256, 256)
    tokens = sequencer(images)
    assert tokens.shape[0] == 4


def test_vae_audio_sequencer_output_shape():
    """A (1, 48000) input must give a last dimension of 7168 and a leading dimension of 1."""
    sequencer = VAEAudioSequencer()
    waveform = torch.randn(1, 48000)
    tokens = sequencer(waveform)
    assert tokens.shape[-1] == 7168
    assert tokens.shape[0] == 1


def test_vae_audio_sequencer_mono_tensor():
    """A 3-D (1, 1, 16000) input must be accepted and give a last dimension of 7168.

    NOTE(review): the middle axis is presumably a mono channel axis, going by
    the test name -- confirm against VAEAudioSequencer.
    """
    sequencer = VAEAudioSequencer()
    waveform = torch.randn(1, 1, 16000)
    tokens = sequencer(waveform)
    assert tokens.shape[-1] == 7168


def test_vae_audio_sequencer_batch():
    """The leading output dimension must match an input batch size of 2."""
    sequencer = VAEAudioSequencer()
    batch_size = 2
    waveforms = torch.randn(batch_size, 16000)
    tokens = sequencer(waveforms)
    assert tokens.shape[0] == 2


def test_vae_audio_no_moonshine_params():
    """The audio sequencer must expose fewer than 100,000 trainable parameters.

    Only parameters with ``requires_grad`` set are counted.
    NOTE(review): the test name suggests this guards against a Moonshine
    encoder being trainable -- the code itself only checks the count.
    """
    sequencer = VAEAudioSequencer()
    trainable = [param for param in sequencer.parameters() if param.requires_grad]
    trainable_count = sum(param.numel() for param in trainable)
    assert trainable_count < 100_000


def test_vae_audio_output_range():
    """Output for a (1, 16000) input must contain only finite values.

    The global torch RNG is seeded first so that the random input (and any
    random initialisation done while constructing the sequencer) is the same
    on every run, which keeps a failure reproducible.
    """
    # Seed before construction so both weight init and input are fixed.
    torch.manual_seed(0)
    seq = VAEAudioSequencer()
    audio = torch.randn(1, 16000)
    out = seq(audio)
    assert torch.isfinite(out).all()


def test_vae_audio_variable_length():
    """A longer input must produce a longer output along dimension 1.

    Compares an 8000-sample input against a 48000-sample input fed through
    the same sequencer instance.
    """
    sequencer = VAEAudioSequencer()
    short_clip = torch.randn(1, 8000)
    long_clip = torch.randn(1, 48000)
    short_tokens = sequencer(short_clip)
    long_tokens = sequencer(long_clip)
    short_length = short_tokens.shape[1]
    long_length = long_tokens.shape[1]
    assert short_length < long_length