Tri-Modal Transformer with Hard-coded Modality Routing (Toy-Scale, 12.1M params)

A toy-scale (12.1M-parameter) multimodal Transformer that processes vision, audio, and text simultaneously, combining shared multi-head attention with hard-coded modality-specific FFN routing.

Architecture

  • Shared Multi-Head Attention across all modalities (cross-modal interaction)
  • Three separate FFN experts (FFN_vision, FFN_audio, FFN_text) with hard-coded routing by modality type, no learnable gating
  • Asymmetric attention mask: Vision↔Audio bidirectional, Text→Vision/Audio allowed, Vision/Audio→Text blocked, Text internally causal (both sketched after this list)
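
A minimal sketch of the block and mask described above, assuming the three token streams are concatenated in [vision | audio | text] order. Class and helper names here are illustrative, not the exact API of model.py:

import torch
import torch.nn as nn

def build_asymmetric_mask(n_vis, n_aud, n_txt):
    """Boolean mask, True = attention allowed. Layout: [vision | audio | text]."""
    n = n_vis + n_aud + n_txt
    mask = torch.zeros(n, n, dtype=torch.bool)
    v = slice(0, n_vis)
    a = slice(n_vis, n_vis + n_aud)
    t = slice(n_vis + n_aud, n)
    mask[v, v] = True   # vision internal (full)
    mask[a, a] = True   # audio internal (full)
    mask[v, a] = True   # Vision -> Audio
    mask[a, v] = True   # Audio -> Vision (bidirectional)
    mask[t, v] = True   # Text -> Vision
    mask[t, a] = True   # Text -> Audio
    # Vision/Audio -> Text stays False (blocked).
    mask[t, t] = torch.tril(torch.ones(n_txt, n_txt, dtype=torch.bool))  # text causal
    return mask

class TriModalBlock(nn.Module):
    """Shared attention; hard-coded per-modality FFN routing, no gating."""
    def __init__(self, d_model=256, n_heads=4, ffn_dim=512, dropout=0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout,
                                          batch_first=True)
        self.ffn = nn.ModuleDict({
            m: nn.Sequential(nn.Linear(d_model, ffn_dim), nn.GELU(),
                             nn.Linear(ffn_dim, d_model))
            for m in ("vision", "audio", "text")
        })
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x, spans, mask):
        # Shared multi-head attention over the full multimodal sequence.
        h = self.norm1(x)
        # nn.MultiheadAttention treats True as "blocked", so invert the mask.
        attn_out, _ = self.attn(h, h, h, attn_mask=~mask)
        x = x + attn_out
        # Hard-coded routing: each modality's span goes to its own FFN expert.
        h = self.norm2(x)
        out = torch.zeros_like(x)
        for name, sl in spans.items():  # e.g. {"vision": slice(0, 16), ...}
            out[:, sl] = self.ffn[name](h[:, sl])
        return x + out

Because routing is keyed on token position rather than learned, there are no router parameters and no load-balancing loss; the ablation below compares this against a learned Top-1 router.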

Key Results (50 epochs)

Metric       3-way (V+A+T)   2-way (V+T)   2-way (A+T)
Vision MSE   0.0397          0.0388        —
Audio MSE    0.0238          —             0.0227
Text CE      2.1804          2.2026        2.0011

No cross-modality interference: Adding a third modality does not degrade individual modality performance (<5% difference).

Ablation: FFN Routing Strategies (Phase 1, Vision+Text)

Routing Strategy                    Vision MSE   Text CE   Params
Separated FFN (hard-coded, ours)    0.0388       2.2026    10.1M
Soft MoE (Top-1 gating)             0.0408       2.6232    10.1M
Shared FFN (no routing)             0.0370       2.5434    9.0M

Hard-coded routing outperforms learnable routing: Soft MoE (Switch Transformer-style Top-1 gating with load balancing) trails hard-coded modality separation on Text CE (2.6232 vs. 2.2026) at an identical parameter budget. The router must learn what modality-based separation provides structurally for free.
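
For reference, a minimal sketch of the Top-1 gating baseline, written as a simplified Switch-style routed FFN; the names and the exact load-balancing formulation are illustrative assumptions, not the ablation's actual code:

import torch
import torch.nn as nn
import torch.nn.functional as F

class Top1MoEFFN(nn.Module):
    """Switch-style Top-1 routed FFN with a load-balancing auxiliary loss."""
    def __init__(self, d_model=256, ffn_dim=512, n_experts=3):
        super().__init__()
        self.router = nn.Linear(d_model, n_experts)
        self.experts = nn.ModuleList([
            nn.Sequential(nn.Linear(d_model, ffn_dim), nn.GELU(),
                          nn.Linear(ffn_dim, d_model))
            for _ in range(n_experts)
        ])

    def forward(self, x):                          # x: (batch, seq, d_model)
        probs = F.softmax(self.router(x), dim=-1)  # (B, S, E)
        top_p, top_i = probs.max(dim=-1)           # Top-1 expert per token
        out = torch.zeros_like(x)
        for e, expert in enumerate(self.experts):
            sel = top_i == e                       # tokens routed to expert e
            if sel.any():
                # Scale by the gate prob so the router gets gradient signal.
                out[sel] = expert(x[sel]) * top_p[sel].unsqueeze(-1)
        # Switch-style load balancing: fraction routed x mean router prob,
        # added to the task loss with a small coefficient during training.
        frac = torch.stack([(top_i == e).float().mean()
                            for e in range(len(self.experts))])
        aux_loss = len(self.experts) * (frac * probs.mean(dim=(0, 1))).sum()
        return out, aux_loss

In the hard-coded variant, the expert index is fixed by each token's modality, so the router and the auxiliary loss disappear entirely.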

Model Config

CONFIG = {
    "d_model": 256,
    "n_heads": 4,
    "ffn_dim": 512,
    "n_layers": 6,
    "vocab_size": 10000,
    "patch_size": 16,
    "max_seq_len": 512,
    "dropout": 0.1,
    "audio_feat_dim": 768,
    "audio_max_tokens": 200,
}
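
For reference, a back-of-envelope parameter count from this config. The breakdown assumes grayscale 64×64 Moving MNIST frames (so a 16×16 patch projects from 256 values), untied input/output text embeddings, and learned positional embeddings; these are assumptions, and the authoritative breakdown is model.py:

d, f, L, V = 256, 512, 6, 10_000

attn_per_layer = 4 * d * d            # Q, K, V, output projections
ffn_per_layer = 3 * (2 * d * f)       # three modality-specific FFN experts
layers = L * (attn_per_layer + ffn_per_layer)   # ~6.29M

text_emb = V * d                      # token embedding (~2.56M)
text_head = d * V                     # untied LM head (~2.56M, assumption)
vis_io = 16 * 16 * d + d * 16 * 16    # patchify + unpatchify (~0.13M, assumption)
aud_io = 768 * d + d * 768            # audio feature in/out projections (~0.39M)
pos_emb = 512 * d                     # learned positions (~0.13M, assumption)

total = layers + text_emb + text_head + vis_io + aud_io + pos_emb
print(f"{total / 1e6:.1f}M params")   # ~12.1M

Dropping the audio expert from every layer (6 × ~0.26M) and the audio projections (~0.39M) lands at ~10.1M, consistent with the Params column of the 2-way (V+T) ablation above.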

Training Details

  • GPU: NVIDIA RTX 5090 32GB (~1.5GB VRAM used)
  • Optimizer: AdamW (lr=3e-4, weight_decay=0.01)
  • Scheduler: Cosine Annealing (eta_min=1e-6)
  • Precision: BFloat16 mixed precision (setup sketched after this list)
  • Data: Moving MNIST (10K) + LibriSpeech-100 WavJEPA features (28.5K) + TinyStories (428M tokens)
  • Training time: ~67 minutes (50 epochs)
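
A minimal sketch of the optimization setup listed above, with a placeholder model and data standing in for model.py/data.py:

import torch
import torch.nn as nn

# Placeholders; the real model and loaders live in model.py / data.py.
model = nn.Linear(256, 256).cuda()
data = [torch.randn(8, 256, device="cuda") for _ in range(10)]
EPOCHS = 50

opt = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.01)
sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=EPOCHS, eta_min=1e-6)

for epoch in range(EPOCHS):
    for x in data:
        opt.zero_grad(set_to_none=True)
        # BFloat16 mixed precision; unlike fp16, bf16 needs no GradScaler.
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            loss = model(x).float().pow(2).mean()  # stand-in for MSE+MSE+CE sum
        loss.backward()
        opt.step()
    sched.step()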

Files

  • best.pt — Best checkpoint (epoch 44, total loss 2.2421; loading sketch after this list)
  • model.py — Model architecture (TriModalTransformerBlock)
  • data.py — Dataset loaders
  • train.py — Training loop
  • training_log.json — Full training log
  • config.json — Hyperparameters
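
A typical way to restore the checkpoint; the checkpoint's key layout and the constructor name below are assumptions, so check train.py and model.py for the real ones:

import json
import torch

ckpt = torch.load("best.pt", map_location="cpu")
with open("config.json") as fh:
    cfg = json.load(fh)

# Assumption: best.pt is either a raw state_dict or wraps it under a "model" key.
state = ckpt.get("model", ckpt)

# from model import TriModalTransformer   # constructor name assumed
# model = TriModalTransformer(**cfg)
# model.load_state_dict(state)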

Paper

arXiv preprint: Hard-coded Modality Routing in Shared-Attention Transformers: A Toy-Scale Empirical Study (forthcoming)

License

MIT
