Tri-Modal Transformer with Hard-coded Modality Routing (Toy-Scale, 12.1M params)
A toy-scale (12.1M parameters) multimodal Transformer that simultaneously processes Vision, Audio, and Text using Shared Attention + Hard-coded Modality-Specific FFN routing.
Architecture
- Shared Multi-Head Attention across all modalities (cross-modal interaction)
- 3 Separated FFN experts:
FFN_vision,FFN_audio,FFN_text(hard-coded routing by modality type, no learnable gating) - Asymmetric attention mask: Vision↔Audio bidirectional, Text→Vision/Audio allowed, Vision/Audio→Text blocked, Text internal causal
Key Results (50 epochs)
| Metric | 3-way (V+A+T) | 2-way (V+T) | 2-way (A+T) |
|---|---|---|---|
| Vision MSE | 0.0397 | 0.0388 | — |
| Audio MSE | 0.0238 | — | 0.0227 |
| Text CE | 2.1804 | 2.2026 | 2.0011 |
No cross-modality interference: Adding a third modality does not degrade individual modality performance (<5% difference).
Ablation: FFN Routing Strategies (Phase 1, Vision+Text)
| Routing Strategy | Vision MSE | Text CE | Params |
|---|---|---|---|
| Separated FFN (hard-coded, ours) | 0.0388 | 2.2026 | 10.1M |
| Soft MoE (Top-1 Gating) | 0.0408 | 2.6232 | 10.1M |
| Shared FFN (no routing) | 0.0370 | 2.5434 | 9.0M |
Hard-coded routing outperforms learnable routing: Soft MoE (Switch Transformer-style Top-1 gating with load balancing) does not outperform hard-coded modality separation on Text CE. The router must learn what modality-based separation provides structurally for free.
Model Config
CONFIG = {
"d_model": 256,
"n_heads": 4,
"ffn_dim": 512,
"n_layers": 6,
"vocab_size": 10000,
"patch_size": 16,
"max_seq_len": 512,
"dropout": 0.1,
"audio_feat_dim": 768,
"audio_max_tokens": 200,
}
Training Details
- GPU: NVIDIA RTX 5090 32GB (~1.5GB VRAM used)
- Optimizer: AdamW (lr=3e-4, weight_decay=0.01)
- Scheduler: Cosine Annealing (eta_min=1e-6)
- Precision: BFloat16 mixed precision
- Data: Moving MNIST (10K) + LibriSpeech-100 WavJEPA features (28.5K) + TinyStories (428M tokens)
- Training time: ~67 minutes (50 epochs)
Files
best.pt— Best checkpoint (epoch 44, total loss 2.2421)model.py— Model architecture (TriModalTransformerBlock)data.py— Dataset loaderstrain.py— Training looptraining_log.json— Full training logconfig.json— Hyperparameters
Paper
arXiv preprint: Hard-coded Modality Routing in Shared-Attention Transformers: A Toy-Scale Empirical Study (forthcoming)
License
MIT
- Downloads last month
- 33
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support