---
language:
- en
license: apache-2.0
tags:
- multimodal
- embedding
- matryoshka
- trimodal
- image-text-audio
- retrieval
- cross-modal
- edge
- rag
library_name: safetensors
pipeline_tag: feature-extraction
datasets:
- custom
---
# AIT-75M: Audio, Image, Text Embeddings
**AIT-75M** maps image, audio, and text into a shared 1280-dim embedding space, enabling cross-modal retrieval with a single vector index. Embeddings from all three modalities support Matryoshka truncation down to 128 dims.
Built for edge deployment: the entire model runs on a Raspberry Pi 5.
> Also available in [GGUF format](https://huggingface.co/augmem/AIT-75M-GGUF) for quantized edge deployment (114 MB at Q8_0).
## Architecture
AIT-75M uses lightweight edge encoders with learned projection heads that expand through a 1920-dim hidden layer before projecting into a shared 1280-dim embedding space:
```
Text --> LEAF-IR (768-d) -----------> DeepProjectionHead (768 -> 1920 -> 1280)
Image --> MobileNetV4-Medium (1280-d) --> DeepProjectionHead (1280 -> 1920 -> 1280)
Audio --> EfficientAT mn20_as (1920-d) --> DeepProjectionHead (1920 -> 1920 -> 1280)
```
All outputs are L2-normalized into the shared 1280-dim space for cross-modal cosine similarity.
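Because every output is unit-length, cosine similarity reduces to a dot product, and a single matrix can hold image, audio, and text vectors side by side. The sketch below assumes embeddings have already been produced by the encoders and projection heads; random unit vectors stand in for them, and `search` is a hypothetical helper rather than part of this release.

```python
import torch
import torch.nn.functional as F

def search(index: torch.Tensor, query: torch.Tensor, k: int = 5):
    """Return (scores, indices) of the top-k index rows by cosine similarity.

    Assumes `index` (M, D) and `query` (D,) are already L2-normalized,
    so the dot product equals cosine similarity.
    """
    scores = index @ query  # (M,)
    return torch.topk(scores, k=min(k, index.shape[0]))

torch.manual_seed(0)
# Stand-ins for projected embeddings from each modality: (n, 1280), unit-norm
image_vecs = F.normalize(torch.randn(3, 1280), dim=-1)
audio_vecs = F.normalize(torch.randn(2, 1280), dim=-1)
text_vecs = F.normalize(torch.randn(4, 1280), dim=-1)

# One index for all modalities, plus a parallel list of (modality, local_id) labels
index = torch.cat([image_vecs, audio_vecs, text_vecs], dim=0)
labels = (
    [("image", i) for i in range(3)]
    + [("audio", i) for i in range(2)]
    + [("text", i) for i in range(4)]
)

# Query with a text embedding; hits can come from any modality
scores, idx = search(index, text_vecs[1], k=3)
top_hits = [labels[i] for i in idx.tolist()]
```

A production deployment would typically swap the concatenated tensor for a vector database or ANN index; the point is that one index serves every query/target modality pair.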
| Component | Architecture | Params | Size |
|---|---|---|---|
| Text encoder | LEAF-IR (MongoDB/mdbr-leaf-ir) | 22.7M | 87.2 MB |
| Image encoder | MobileNetV4-Medium (timm) | 8.4M | 32.4 MB |
| Audio encoder | EfficientAT mn20_as | 17.9M | 68.5 MB |
| Image projection | DeepProjectionHead (1280 -> 1920 -> 1280) | 8.6M | 32.9 MB |
| Audio projection | DeepProjectionHead (1920 -> 1920 -> 1280) | 9.8M | 37.5 MB |
| Text projection | DeepProjectionHead (768 -> 1920 -> 1280) | 7.6M | 29.1 MB |
| **Total** | | **75.2M** | **287.7 MB** |
### Projection head detail
Each `DeepProjectionHead` is a depth-1 residual MLP with Matryoshka-aware training:
```
Linear(encoder_dim, 1920) -> GELU -> LayerNorm -> Dropout(0.2)
-> Linear(1920, 1920) -> GELU -> LayerNorm -> Dropout(0.2) + residual
-> Linear(1920, 1280)
```
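A minimal PyTorch sketch of the head described above. The attribute names (`fc_in`, `norm_in`, and so on) are hypothetical and very likely differ from the key names in `AIT-75M.safetensors`, so loading the released weights into it would require remapping keys. The exact placement of the residual connection is our reading of the diagram, and the final L2 normalization reflects the statement that all outputs are L2-normalized.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class DeepProjectionHead(nn.Module):
    """Depth-1 residual MLP: encoder_dim -> 1920 -> 1920 (+residual) -> 1280.

    Hypothetical re-implementation; parameter names do not match the checkpoint.
    """

    def __init__(self, encoder_dim: int, hidden_dim: int = 1920,
                 out_dim: int = 1280, dropout: float = 0.2):
        super().__init__()
        self.fc_in = nn.Linear(encoder_dim, hidden_dim)
        self.norm_in = nn.LayerNorm(hidden_dim)
        self.fc_hidden = nn.Linear(hidden_dim, hidden_dim)
        self.norm_hidden = nn.LayerNorm(hidden_dim)
        self.fc_out = nn.Linear(hidden_dim, out_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Block 1: Linear -> GELU -> LayerNorm -> Dropout
        h = self.dropout(self.norm_in(F.gelu(self.fc_in(x))))
        # Block 2: same pattern, added back onto its input (residual)
        h = h + self.dropout(self.norm_hidden(F.gelu(self.fc_hidden(h))))
        # Output projection, then L2-normalize for cosine similarity
        return F.normalize(self.fc_out(h), dim=-1)

image_head = DeepProjectionHead(encoder_dim=1280).eval()
with torch.no_grad():
    out = image_head(torch.randn(2, 1280))  # e.g. two MobileNetV4 feature vectors
```

This layout yields 8,614,400 parameters and 10 `state_dict` tensors for the image head, in line with the component and tensor-prefix tables in this card.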
### Matryoshka dimensions
Embeddings can be truncated to `[1280, 768, 512, 256, 128]` dimensions while preserving retrieval quality; the model is trained with Matryoshka Representation Learning (MRL).
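The card does not publish the exact training objective beyond "InfoNCE with MRL", so the following is only a sketch of the common recipe: a symmetric InfoNCE loss computed on each Matryoshka prefix (re-normalized after truncation) and averaged. The temperature of 0.07 and the equal weighting across prefix sizes are assumptions.

```python
import torch
import torch.nn.functional as F

MATRYOSHKA_DIMS = (1280, 768, 512, 256, 128)

def info_nce(a: torch.Tensor, b: torch.Tensor, temperature: float = 0.07) -> torch.Tensor:
    """Symmetric InfoNCE: row i of `a` should match row i of `b` and no other row."""
    logits = a @ b.T / temperature  # (N, N) cosine similarities / temperature
    targets = torch.arange(a.shape[0], device=a.device)
    return 0.5 * (F.cross_entropy(logits, targets) + F.cross_entropy(logits.T, targets))

def matryoshka_info_nce(a: torch.Tensor, b: torch.Tensor, dims=MATRYOSHKA_DIMS) -> torch.Tensor:
    """Average InfoNCE over nested prefixes; each prefix is re-normalized before scoring."""
    losses = [
        info_nce(F.normalize(a[:, :d], dim=-1), F.normalize(b[:, :d], dim=-1))
        for d in dims
    ]
    return torch.stack(losses).mean()

torch.manual_seed(0)
image_emb = torch.randn(8, 1280)
text_emb = image_emb + 0.1 * torch.randn(8, 1280)  # synthetic noisy positives
loss_aligned = matryoshka_info_nce(image_emb, text_emb)
loss_random = matryoshka_info_nce(image_emb, torch.randn(8, 1280))
```

Scoring every prefix in the same step is what pushes the most useful information into the leading dimensions, which is why a truncated vector still retrieves reasonably well.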
## Benchmarks
All benchmarks were run on a single NVIDIA L4 GPU; the SALT evaluation uses 5K trimodal samples.
### Cross-modal retrieval: SALT (5K trimodal samples)
| Direction | AIT-75M (75M) | TEG-421M (421M) | ImageBind (1.2B) | EBind (1.78B*) |
|---|---|---|---|---|
| Image -> Text R@1 | 0.615 | 0.620 | 0.736 | **0.783** |
| Text -> Image R@1 | 0.614 | 0.672 | 0.712 | **0.779** |
| Text -> Audio R@1 | 0.103 | **0.113** | 0.038 | 0.047 |
| Audio -> Text R@1 | 0.082 | **0.115** | 0.039 | 0.035 |
| Image -> Audio R@1 | 0.062 | **0.083** | 0.023 | 0.027 |
| Audio -> Image R@1 | 0.063 | **0.081** | 0.025 | 0.032 |
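R@1 is read here as the standard paired-retrieval metric: for "X -> Y", each X item is a query, the correct Y item sits at the same index, and R@1 is the fraction of queries whose top-ranked Y item is that match. The evaluation script is not published, so this is a sketch under that assumption, using synthetic embeddings.

```python
import torch
import torch.nn.functional as F

def recall_at_k(queries: torch.Tensor, items: torch.Tensor, k: int = 1) -> float:
    """Fraction of queries whose paired item (same row index) is in the top-k.

    Both inputs are (N, D); they are L2-normalized so the dot product is cosine.
    """
    q = F.normalize(queries, dim=-1)
    x = F.normalize(items, dim=-1)
    sims = q @ x.T                                   # (N, N)
    topk = sims.topk(k, dim=-1).indices              # (N, k)
    targets = torch.arange(q.shape[0]).unsqueeze(1)  # (N, 1)
    return (topk == targets).any(dim=-1).float().mean().item()

torch.manual_seed(0)
text = torch.randn(100, 1280)
image = text + 0.5 * torch.randn(100, 1280)  # synthetic, partially aligned pairs
print(recall_at_k(text, image, k=1))  # Text -> Image R@1
print(recall_at_k(image, text, k=1))  # Image -> Text R@1
```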
### Audio retrieval: AudioCaps & Clotho
| Benchmark | Direction | AIT-75M | CLAP-Large | ImageBind | EBind |
|---|---|---|---|---|---|
| AudioCaps | A->T R@1 | 0.210 | **0.420** | 0.116 | 0.225 |
| AudioCaps | T->A R@1 | 0.148 | **0.280** | 0.080 | 0.219 |
| Clotho | A->T R@1 | **0.208** | 0.195 | 0.061 | 0.088 |
| Clotho | T->A R@1 | **0.172** | 0.167 | 0.074 | 0.118 |
AIT-75M posts the highest Clotho A->T R@1 of all compared models, including CLAP-Large, while being fully trimodal.
### Image-text retrieval: MSCOCO & Flickr30k
| Benchmark | Direction | AIT-75M (75M) | EBind (1.78B*) | ImageBind (1.2B) |
|---|---|---|---|---|
| Flickr30k | I->T R@1 | 0.478 | **0.951** | 0.918 |
| Flickr30k | T->I R@1 | 0.303 | **0.853** | 0.766 |
| MSCOCO 5K | I->T R@1 | 0.320 | **0.743** | 0.658 |
| MSCOCO 5K | T->I R@1 | 0.208 | **0.559** | 0.490 |
### Zero-shot classification: ESC-50
| Model | Params | Accuracy |
|---|---|---|
| CLAP-Large | 67.8M | 90.5% |
| AIT-75M | 75M | **93.2%** |
| EBind | 1.78B* | 77.0% |
| ImageBind | 1.2B | 66.4% |
**Highest ESC-50 accuracy among the compared models** (93.2%) at 75M params, beating CLAP-Large (90.5%) while being trimodal.
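Zero-shot classification falls out of the shared space: embed one text prompt per class, embed each audio clip, and predict the class with the highest cosine similarity. This sketch assumes pre-computed embeddings and uses synthetic data; the prompt wording in the comment is a hypothetical example, since the card does not state which prompts were used.

```python
import torch
import torch.nn.functional as F

def zero_shot_classify(audio_emb: torch.Tensor, class_text_emb: torch.Tensor) -> torch.Tensor:
    """Return the predicted class index for each audio clip.

    audio_emb:      (N, D) audio embeddings from the audio projection head
    class_text_emb: (C, D) one text embedding per class label, e.g. for a
                    hypothetical prompt such as "the sound of a dog barking"
    """
    a = F.normalize(audio_emb, dim=-1)
    t = F.normalize(class_text_emb, dim=-1)
    return (a @ t.T).argmax(dim=-1)  # (N,)

def accuracy(pred: torch.Tensor, labels: torch.Tensor) -> float:
    return (pred == labels).float().mean().item()

torch.manual_seed(0)
num_classes = 50  # ESC-50 has 50 classes
class_text = torch.randn(num_classes, 1280)
labels = torch.randint(0, num_classes, (200,))
audio = class_text[labels] + 0.3 * torch.randn(200, 1280)  # synthetic clips near their class
pred = zero_shot_classify(audio, class_text)
print(f"accuracy: {accuracy(pred, labels):.3f}")
```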
### Text retrieval: MTEB (NDCG@10)
Text-text retrieval quality in the shared embedding space, measured on MTEB retrieval tasks:
| Task | AIT-75M | Raw LEAF-IR | Recovery (AIT-75M / raw) |
|---|---|---|---|
| ArguAna | 0.544 | 0.594 | 92% |
| CQADupstackGaming | 0.506 | 0.607 | 83% |
| CQADupstackUnix | 0.355 | 0.428 | 83% |
| FEVERHardNegatives | 0.551 | 0.863 | 64% |
| HotpotQAHardNegatives | 0.531 | 0.700 | 76% |
| FiQA2018 | 0.292 | 0.392 | 74% |
| ClimateFEVER | 0.215 | 0.353 | 61% |
| SCIDOCS | 0.153 | 0.198 | 77% |
| TRECCOVID | 0.474 | 0.820 | 58% |
The text projection head recovers 58-92% of raw LEAF-IR's retrieval quality while mapping into the cross-modal shared space.
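The Recovery column is the ratio of the two NDCG@10 scores, rounded to a whole percent. A small check over the values copied from the table above:

```python
# NDCG@10 values from the MTEB table: (AIT-75M, raw LEAF-IR)
scores = {
    "ArguAna": (0.544, 0.594),
    "CQADupstackGaming": (0.506, 0.607),
    "CQADupstackUnix": (0.355, 0.428),
    "FEVERHardNegatives": (0.551, 0.863),
    "HotpotQAHardNegatives": (0.531, 0.700),
    "FiQA2018": (0.292, 0.392),
    "ClimateFEVER": (0.215, 0.353),
    "SCIDOCS": (0.153, 0.198),
    "TRECCOVID": (0.474, 0.820),
}

def recovery(projected: float, raw: float) -> int:
    """Share of the raw encoder's NDCG@10 retained after projection, in whole percent."""
    return round(100 * projected / raw)

recoveries = {task: recovery(ait, raw) for task, (ait, raw) in scores.items()}
for task, pct in recoveries.items():
    print(f"{task}: {pct}%")
```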
## Usage
### Loading components
```python
from safetensors.torch import load_file
# Load entire model
tensors = load_file("AIT-75M.safetensors")
# Extract components by prefix
text_enc_sd = {k.removeprefix("text_encoder."): v for k, v in tensors.items() if k.startswith("text_encoder.")}
image_enc_sd = {k.removeprefix("image_encoder."): v for k, v in tensors.items() if k.startswith("image_encoder.")}
audio_enc_sd = {k.removeprefix("audio_encoder."): v for k, v in tensors.items() if k.startswith("audio_encoder.")}
image_proj_sd = {k.removeprefix("image_projection."): v for k, v in tensors.items() if k.startswith("image_projection.")}
audio_proj_sd = {k.removeprefix("audio_projection."): v for k, v in tensors.items() if k.startswith("audio_projection.")}
text_proj_sd = {k.removeprefix("text_projection."): v for k, v in tensors.items() if k.startswith("text_projection.")}
```
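The six comprehensions above can also be collapsed into a single pass that groups every tensor by its top-level prefix. A sketch: `split_by_prefix` is a hypothetical helper that works on any mapping with dotted key names, such as the dict returned by `load_file`. The example keys are illustrative only; the card does not list the real key names.

```python
from collections import defaultdict
from typing import Dict, Mapping, TypeVar

T = TypeVar("T")

def split_by_prefix(tensors: Mapping[str, T]) -> Dict[str, Dict[str, T]]:
    """Group 'component.rest.of.key' entries into {component: {rest.of.key: value}}."""
    groups: Dict[str, Dict[str, T]] = defaultdict(dict)
    for key, value in tensors.items():
        component, _, rest = key.partition(".")
        groups[component][rest] = value
    return dict(groups)

# With the real file (requires safetensors):
#   from safetensors.torch import load_file
#   parts = split_by_prefix(load_file("AIT-75M.safetensors"))
#   text_proj_sd = parts["text_projection"]

# Illustrative keys only, with placeholder strings instead of tensors
example = {
    "text_projection.0.weight": "w",
    "text_projection.0.bias": "b",
    "image_encoder.conv.weight": "c",
}
parts = split_by_prefix(example)
```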
### Matryoshka truncation
```python
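import torch
import torch.nn.functional as F
# Full 1280-dim embedding, e.g. `embedding = model(input)` where `model` is an
# encoder followed by its projection head. Random stand-in values here:
embedding = F.normalize(torch.randn(4, 1280), dim=-1)  # (N, 1280)
# Truncate to 256 dims and re-normalize so cosine similarity stays valid
embedding_256 = F.normalize(embedding[:, :256], dim=-1)  # (N, 256)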
```
## File layout
```
AIT-75M.safetensors # All components in one file (~288 MB)
```
### Tensor key prefixes
| Prefix | Component | Tensors |
|---|---|---|
| `text_encoder.*` | LEAF-IR (float32) | 103 |
| `image_encoder.*` | MobileNetV4-Medium | 462 |
| `audio_encoder.*` | EfficientAT mn20_as | 312 |
| `image_projection.*` | Projection head | 10 |
| `audio_projection.*` | Projection head | 10 |
| `text_projection.*` | Projection head | 10 |
## Training
- **Loss**: InfoNCE (contrastive) with Matryoshka Representation Learning
- **Data**: ~2.2M synthetically generated trimodal triplets (WordNet), 200K MSCOCO image-text pairs, 262K WavCaps audio-text pairs, and 1.5M Nomic text pairs
- **Hardware**: 2x NVIDIA L4 GPUs
- **Text retrieval fine-tune**: Phase 1 warm-starts from the d20 checkpoint and trains the text head only (image/audio heads frozen), mixing in Nomic supervised text pairs at lambda_tt=0.25
- **Optimizer**: AdamW, lr=1e-3, weight decay=1e-4, cosine scheduler
- **Epochs**: 7 (text fine-tune from pre-trained trimodal base)
- **Projection heads only**: source encoders are frozen during training
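A sketch of the optimization setup listed above: source encoders frozen, only the projection heads handed to AdamW (lr=1e-3, weight decay=1e-4) with a cosine schedule. The small `nn.Linear` modules are hypothetical placeholders for the real encoders and heads, the squared-output loss is a stand-in for the contrastive objective, and stepping the schedule once per epoch over 7 epochs is an assumption; the card does not say whether it is stepped per batch or per epoch.

```python
import torch
import torch.nn as nn

# Hypothetical placeholders for the frozen source encoders and trainable heads
encoders = nn.ModuleDict({
    "text": nn.Linear(16, 768), "image": nn.Linear(16, 1280), "audio": nn.Linear(16, 1920),
})
heads = nn.ModuleDict({
    "text": nn.Linear(768, 1280), "image": nn.Linear(1280, 1280), "audio": nn.Linear(1920, 1280),
})

# Freeze source encoders: no gradients, eval mode
encoders.requires_grad_(False)
encoders.eval()

EPOCHS = 7
optimizer = torch.optim.AdamW(heads.parameters(), lr=1e-3, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

for epoch in range(EPOCHS):
    optimizer.zero_grad()
    with torch.no_grad():
        feats = encoders["text"](torch.randn(4, 16))  # frozen encoder features
    loss = heads["text"](feats).pow(2).mean()  # placeholder loss, not the real InfoNCE objective
    loss.backward()
    optimizer.step()
    scheduler.step()  # assumed: one cosine step per epoch
```

Because the encoders never receive gradients, their features could also be pre-computed once and cached, which keeps head-only training cheap.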
### Design decisions
- **3-head shared space**: All modalities project into a learned 1280-dim space (the native output dimension of MobileNetV4-Medium) instead of targeting a pre-existing text encoder space
- **LEAF-IR text encoder**: A 23M-param retrieval-optimized text encoder replaces a 300M-param Gemma encoder, enabling fully edge-deployable text inference
- **Frozen source encoders**: MobileNetV4, EfficientAT, and LEAF-IR are kept frozen; only projection heads are trained
- **Text retrieval fine-tune**: Nomic supervised text pairs (1.5M) mixed into trimodal training to improve text-text retrieval while preserving cross-modal alignment
- **Edge-first**: All source encoders can run on devices such as the Raspberry Pi 5
## Limitations
- Audio-text retrieval lags behind specialist models such as CLAP-Large on AudioCaps, although AIT-75M is competitive on Clotho
- Image-text retrieval trails much larger models such as EBind and ImageBind on Flickr30k and MSCOCO, trading accuracy for edge deployability
- Text retrieval recovers 58-92% of raw LEAF-IR quality (gap is domain-dependent)
## Links
- **Website**: [augmem.ai](https://augmem.ai)
- **GitHub**: [github.com/augmem](https://github.com/augmem)
## License
Apache 2.0