| """Cross-modal integration tests for Phase 20.""" |
| import os |
| import torch |
| import pytest |
| from arbitor.main import ARBModel |
|
|
# Module-wide opt-in gate: every test in this file is skipped unless the
# environment variable ARB_RUN_SLOW_TESTS is set to exactly "1".
pytestmark = pytest.mark.skipif(
    os.getenv("ARB_RUN_SLOW_TESTS") != "1",
    reason="full cross-modal ARBModel tests require the 3B target model and sidecar encoders",
)
|
|
|
|
def test_cross_modality_unified_latent():
    """Run text, image and audio through a single ``ARBModel`` forward pass.

    Builds a model with both the image and audio paths enabled, feeds it
    random inputs, and checks that logits and indices are returned and that
    the second dimension of ``indices`` is larger than the text length.
    """
    text_len = 50

    model = ARBModel(enable_image=True, enable_audio=True)
    model.eval()

    text = torch.randint(0, 288, (1, text_len))
    img = torch.randn(1, 3, 256, 256)
    # NOTE(review): presumably 3 s of 16 kHz mono audio -- confirm against
    # the audio encoder's expected input.
    audio = torch.randn(1, 16000 * 3)

    # Only the outputs are inspected, so skip autograd bookkeeping: the
    # forward pass then does not retain activations for a backward pass.
    with torch.no_grad():
        logits, _losses, indices, _ = model(text, images=img, audio=audio)

    assert logits is not None
    assert indices is not None
    # NOTE(review): the extra positions presumably come from the image and
    # audio inputs -- confirm against ARBModel.
    assert indices.shape[1] > text_len
|
|
|
|
def test_text_only_still_works():
    """Check that ``ARBModel`` still produces logits with text input only.

    Builds a model with the image and audio paths disabled, feeds it random
    token ids, and checks that the last dimension of the returned logits
    equals the bound the token ids were drawn from.
    """
    # Used both as the exclusive upper bound for the random token ids and as
    # the expected size of the last logits dimension.
    vocab_size = 288

    model = ARBModel(enable_image=False, enable_audio=False)
    model.eval()

    text = torch.randint(0, vocab_size, (1, 50))

    # Only the outputs are inspected, so skip autograd bookkeeping: the
    # forward pass then does not retain activations for a backward pass.
    with torch.no_grad():
        logits, _losses, _indices, _ = model(text)

    assert logits is not None
    assert logits.shape[-1] == vocab_size
|
|
|
|
def test_image_only():
    """Run text plus an image (no audio) through ``ARBModel``.

    Builds a model with only the image path enabled, feeds it a short random
    text prompt and a random image, and checks that logits and indices are
    returned and that the second dimension of ``indices`` is larger than the
    text length.
    """
    text_len = 10

    model = ARBModel(enable_image=True, enable_audio=False)
    model.eval()

    text = torch.randint(0, 288, (1, text_len))
    img = torch.randn(1, 3, 224, 224)

    # Only the outputs are inspected, so skip autograd bookkeeping: the
    # forward pass then does not retain activations for a backward pass.
    with torch.no_grad():
        logits, _losses, indices, _ = model(text, images=img)

    assert logits is not None
    # Fail with a clear assertion instead of an AttributeError on ``.shape``
    # if the model returns no indices.
    assert indices is not None
    # NOTE(review): the extra positions presumably come from the image
    # input -- confirm against ARBModel.
    assert indices.shape[1] > text_len
|
|