---
language:
- en
license: apache-2.0
tags:
- multimodal
- embedding
- matryoshka
- trimodal
- image-text-audio
- retrieval
- cross-modal
- edge
- rag
library_name: safetensors
pipeline_tag: feature-extraction
datasets:
- custom
---

# AIT-75M: Audio, Image, Text Embeddings

**AIT-75M** maps image, audio, and text into a shared 1280-dim embedding space, so a single vector index can serve cross-modal retrieval. Embeddings from all three modalities support Matryoshka truncation down to 128 dims.

Built for edge deployment: the entire model runs on a Raspberry Pi 5.

> Also available in [GGUF format](https://huggingface.co/augmem/AIT-75M-GGUF) for quantized edge deployment (114 MB at Q8_0).

## Architecture

AIT-75M uses lightweight edge encoders with learned projection heads that expand through a 1920-dim hidden layer before projecting into a shared 1280-dim embedding space:

```
Text  --> LEAF-IR (768-d) -----------> DeepProjectionHead (768 -> 1920 -> 1280)
Image --> MobileNetV4-Medium (1280-d) --> DeepProjectionHead (1280 -> 1920 -> 1280)
Audio --> EfficientAT mn20_as (1920-d) --> DeepProjectionHead (1920 -> 1920 -> 1280)
```

All outputs are L2-normalized into the shared 1280-dim space for cross-modal cosine similarity.
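
Because every modality ends up as a unit-length vector in the same space, cross-modal search reduces to a dot product against one index. The following is a minimal pure-Python sketch of that idea, not the project's retrieval code: the helper names, item ids, and 4-dim toy vectors are all made up for illustration.

```python
import math

def l2_normalize(vec):
    """Scale a vector to unit length (zero vectors are returned unchanged)."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec] if norm > 0 else list(vec)

def cosine_top_k(query, index, k=3):
    """Return the k best (item_id, score) pairs by cosine similarity.

    Both the query and the indexed vectors are assumed to be unit-length,
    so the dot product equals the cosine similarity.
    """
    scores = [(item_id, sum(q * v for q, v in zip(query, vec)))
              for item_id, vec in index.items()]
    return sorted(scores, key=lambda pair: pair[1], reverse=True)[:k]

# Toy 4-dim stand-ins for 1280-dim embeddings; ids and values are hypothetical.
index = {
    "image:dog.jpg": l2_normalize([0.9, 0.1, 0.0, 0.1]),
    "audio:bark.wav": l2_normalize([0.8, 0.2, 0.1, 0.0]),
    "text:stock report": l2_normalize([0.0, 0.1, 0.9, 0.3]),
}
query = l2_normalize([1.0, 0.1, 0.0, 0.0])  # e.g. the embedded text query "a dog"
results = cosine_top_k(query, index, k=2)
```

In this toy index the image and audio items rank ahead of the unrelated text item, which is the behaviour a shared space is meant to provide. A real deployment would likely replace the linear scan with an approximate nearest-neighbour index.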

| Component | Architecture | Params | Size |
|---|---|---|---|
| Text encoder | LEAF-IR (MongoDB/mdbr-leaf-ir) | 22.7M | 87.2 MB |
| Image encoder | MobileNetV4-Medium (timm) | 8.4M | 32.4 MB |
| Audio encoder | EfficientAT mn20_as | 17.9M | 68.5 MB |
| Image projection | DeepProjectionHead (1280 -> 1920 -> 1280) | 8.6M | 32.9 MB |
| Audio projection | DeepProjectionHead (1920 -> 1920 -> 1280) | 9.8M | 37.5 MB |
| Text projection | DeepProjectionHead (768 -> 1920 -> 1280) | 7.6M | 29.1 MB |
| **Total** | | **75.2M** | **287.7 MB** |

### Projection head detail

Each `DeepProjectionHead` is a depth-1 residual MLP with Matryoshka-aware training:

```
Linear(encoder_dim, 1920) -> GELU -> LayerNorm -> Dropout(0.2)
  -> Linear(1920, 1920) -> GELU -> LayerNorm -> Dropout(0.2) + residual
  -> Linear(1920, 1280)
```
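
As a sanity check on the layer layout above, the parameter counts in the component table can be reproduced with simple arithmetic. This sketch assumes every `Linear` has a bias, every `LayerNorm` has learnable scale and shift, weights are stored as float32, and "MB" in the table means 2^20 bytes; the function names are illustrative.

```python
def linear_params(in_dim, out_dim):
    """Weights plus biases of a fully connected layer."""
    return in_dim * out_dim + out_dim

def layernorm_params(dim):
    """Scale and shift vectors of an affine LayerNorm."""
    return 2 * dim

def projection_head_params(encoder_dim, hidden_dim=1920, out_dim=1280):
    """Parameter count for Linear -> LN -> Linear -> LN -> Linear."""
    return (
        linear_params(encoder_dim, hidden_dim) + layernorm_params(hidden_dim)
        + linear_params(hidden_dim, hidden_dim) + layernorm_params(hidden_dim)
        + linear_params(hidden_dim, out_dim)
    )

def size_mib(num_params, bytes_per_param=4):
    """Storage size in MiB, assuming float32 weights by default."""
    return num_params * bytes_per_param / 2**20

heads = {"text": 768, "image": 1280, "audio": 1920}
for name, encoder_dim in heads.items():
    n = projection_head_params(encoder_dim)
    print(f"{name}: {n / 1e6:.1f}M params, {size_mib(n):.1f} MiB")
```

Under these assumptions the counts come out at 7.6M, 8.6M, and 9.8M parameters for the text, image, and audio heads, matching the table. Three `Linear` layers and two `LayerNorm` layers with weight and bias each also account for the 10 tensors listed per projection head.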

### Matryoshka dimensions

Embeddings can be truncated to `[1280, 768, 512, 256, 128]` dimensions while preserving retrieval quality, because the projection heads are trained with Matryoshka Representation Learning (MRL).

## Benchmarks

All benchmarks were run on a single NVIDIA L4 GPU; the SALT evaluation uses 5K samples.

### Cross-modal retrieval: SALT (5K trimodal samples)

| Direction | AIT-75M (75M) | TEG-421M (421M) | ImageBind (1.2B) | EBind (1.78B*) |
|---|---|---|---|---|
| Image -> Text R@1 | 0.615 | 0.620 | 0.736 | **0.783** |
| Text -> Image R@1 | 0.614 | 0.672 | 0.712 | **0.779** |
| Text -> Audio R@1 | 0.103 | **0.113** | 0.038 | 0.047 |
| Audio -> Text R@1 | 0.082 | **0.115** | 0.039 | 0.035 |
| Image -> Audio R@1 | 0.062 | **0.083** | 0.023 | 0.027 |
| Audio -> Image R@1 | 0.063 | **0.081** | 0.025 | 0.032 |
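
For readers unfamiliar with the metric, R@1 is the fraction of queries whose correct item is ranked first. The sketch below shows one common way to compute Recall@k from a query-by-candidate similarity matrix. It assumes a one-to-one pairing in which candidate `i` is the only correct match for query `i`; the exact SALT protocol is not described here, so treat this as illustrative.

```python
def recall_at_k(similarity, k=1):
    """Fraction of queries whose matching item ranks in the top k.

    similarity[i][j] is the score between query i and candidate j;
    candidate i is assumed to be the single correct match for query i.
    """
    hits = 0
    for i, row in enumerate(similarity):
        # Indices of the k highest-scoring candidates for this query
        top_k = sorted(range(len(row)), key=lambda j: row[j], reverse=True)[:k]
        hits += i in top_k
    return hits / len(similarity)

# Toy 3x3 score matrix with made-up values: queries 0 and 2 rank their
# match first, query 1 ranks it second.
scores = [
    [0.9, 0.2, 0.1],
    [0.6, 0.5, 0.3],
    [0.1, 0.2, 0.8],
]
r1 = recall_at_k(scores, k=1)
r2 = recall_at_k(scores, k=2)
```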

### Audio retrieval: AudioCaps & Clotho

| Benchmark | Direction | AIT-75M | CLAP-Large | ImageBind | EBind |
|---|---|---|---|---|---|
| AudioCaps | A->T R@1 | 0.210 | **0.420** | 0.116 | 0.225 |
| AudioCaps | T->A R@1 | 0.148 | **0.280** | 0.080 | 0.219 |
| Clotho | A->T R@1 | **0.208** | 0.195 | 0.061 | 0.088 |
| Clotho | T->A R@1 | **0.172** | 0.167 | 0.074 | 0.118 |

AIT-75M has the highest Clotho A->T R@1 of all models compared, including CLAP-Large, while being fully trimodal.

### Image-text retrieval: MSCOCO & Flickr30k

| Benchmark | Direction | AIT-75M (75M) | EBind (1.78B*) | ImageBind (1.2B) |
|---|---|---|---|---|
| Flickr30k | I->T R@1 | 0.478 | **0.951** | 0.918 |
| Flickr30k | T->I R@1 | 0.303 | **0.853** | 0.766 |
| MSCOCO 5K | I->T R@1 | 0.320 | **0.743** | 0.658 |
| MSCOCO 5K | T->I R@1 | 0.208 | **0.559** | 0.490 |

### Zero-shot classification: ESC-50

| Model | Params | Accuracy |
|---|---|---|
| CLAP-Large | 67.8M | 90.5% |
| AIT-75M | 75M | **93.2%** |
| EBind | 1.78B* | 77.0% |
| ImageBind | 1.2B | 66.4% |

**#1 on ESC-50 among the models compared** (93.2%) at 75M params, ahead of CLAP-Large (90.5%) while being trimodal.
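
Zero-shot classification in a shared embedding space is typically done by embedding one text prompt per class and picking the class closest to the audio clip. The sketch below illustrates that general recipe with toy 3-dim vectors. The function names, the prompt wording, and the vectors are assumptions; the exact ESC-50 evaluation setup is not specified here.

```python
import math

def cosine(a, b):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

def zero_shot_classify(audio_embedding, label_embeddings):
    """Return the label whose text embedding is closest to the audio clip."""
    return max(label_embeddings,
               key=lambda label: cosine(audio_embedding, label_embeddings[label]))

# Toy 3-dim stand-ins; real inputs would be 1280-dim embeddings of prompts
# such as "the sound of a dog" (the prompt wording is an assumption).
label_embeddings = {
    "dog": [0.9, 0.1, 0.0],
    "rain": [0.1, 0.9, 0.1],
    "siren": [0.0, 0.2, 0.9],
}
clip = [0.8, 0.3, 0.1]
predicted = zero_shot_classify(clip, label_embeddings)
```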

### Text retrieval: MTEB (NDCG@10)

Text-text retrieval quality in the shared embedding space, measured on MTEB retrieval tasks:

| Task | AIT-75M | Raw LEAF-IR | Recovery |
|---|---|---|---|
| ArguAna | 0.544 | 0.594 | 92% |
| CQADupstackGaming | 0.506 | 0.607 | 83% |
| CQADupstackUnix | 0.355 | 0.428 | 83% |
| FEVERHardNegatives | 0.551 | 0.863 | 64% |
| HotpotQAHardNegatives | 0.531 | 0.700 | 76% |
| FiQA2018 | 0.292 | 0.392 | 74% |
| ClimateFEVER | 0.215 | 0.353 | 61% |
| SCIDOCS | 0.153 | 0.198 | 77% |
| TRECCOVID | 0.474 | 0.820 | 58% |

The text projection head recovers 58-92% of raw LEAF-IR's retrieval quality while mapping into the cross-modal shared space.
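
The Recovery column is simply the projected score divided by the raw encoder score. A short sketch that recomputes it from the NDCG@10 values in the table (the function and variable names are illustrative):

```python
# NDCG@10 values copied from the table: (AIT-75M, raw LEAF-IR)
ndcg = {
    "ArguAna": (0.544, 0.594),
    "CQADupstackGaming": (0.506, 0.607),
    "CQADupstackUnix": (0.355, 0.428),
    "FEVERHardNegatives": (0.551, 0.863),
    "HotpotQAHardNegatives": (0.531, 0.700),
    "FiQA2018": (0.292, 0.392),
    "ClimateFEVER": (0.215, 0.353),
    "SCIDOCS": (0.153, 0.198),
    "TRECCOVID": (0.474, 0.820),
}

def recovery_percent(projected, raw):
    """Share of the raw encoder's score retained after projection, in percent."""
    return round(100 * projected / raw)

recovery = {task: recovery_percent(*scores) for task, scores in ndcg.items()}
lowest, highest = min(recovery.values()), max(recovery.values())
```

Here `lowest` and `highest` give the two ends of the 58-92% range quoted above, from TRECCOVID and ArguAna respectively.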

## Usage

### Loading components

```python
from safetensors.torch import load_file

# Load entire model
tensors = load_file("AIT-75M.safetensors")

# Extract components by prefix
text_enc_sd = {k.removeprefix("text_encoder."): v for k, v in tensors.items() if k.startswith("text_encoder.")}
image_enc_sd = {k.removeprefix("image_encoder."): v for k, v in tensors.items() if k.startswith("image_encoder.")}
audio_enc_sd = {k.removeprefix("audio_encoder."): v for k, v in tensors.items() if k.startswith("audio_encoder.")}
image_proj_sd = {k.removeprefix("image_projection."): v for k, v in tensors.items() if k.startswith("image_projection.")}
audio_proj_sd = {k.removeprefix("audio_projection."): v for k, v in tensors.items() if k.startswith("audio_projection.")}
text_proj_sd = {k.removeprefix("text_projection."): v for k, v in tensors.items() if k.startswith("text_projection.")}
```
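
The six comprehensions above follow one pattern, so they can be folded into a small helper. This is an optional sketch, not part of the released code: `split_by_prefix` and the fake key names are hypothetical, and plain strings stand in for tensors so the example runs without the model file.

```python
def split_by_prefix(tensors, prefixes):
    """Group a flat tensor dict into one state dict per component prefix.

    Keys are returned with the "<prefix>." part stripped, mirroring the
    dict comprehensions above.
    """
    groups = {prefix: {} for prefix in prefixes}
    for key, value in tensors.items():
        prefix, sep, rest = key.partition(".")
        if sep and prefix in groups:
            groups[prefix][rest] = value
    return groups

PREFIXES = ("text_encoder", "image_encoder", "audio_encoder",
            "image_projection", "audio_projection", "text_projection")

# Stand-in for load_file("AIT-75M.safetensors"); the key names after each
# prefix are made up for illustration.
fake_tensors = {
    "text_projection.fc1.weight": "tensor-a",
    "text_projection.fc1.bias": "tensor-b",
    "image_encoder.stem.weight": "tensor-c",
}
state_dicts = split_by_prefix(fake_tensors, PREFIXES)
counts = {prefix: len(sd) for prefix, sd in state_dicts.items()}
```

With the real file, `counts` offers a quick integrity check against the per-prefix tensor counts listed under "Tensor key prefixes".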

### Matryoshka truncation

```python
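import torch
import torch.nn.functional as F

# Stand-in for a batch of full 1280-dim embeddings; in practice this is the
# L2-normalized output of an encoder followed by its projection head.
embedding = F.normalize(torch.randn(4, 1280), dim=-1)  # (N, 1280)

# Truncate to 256-dim and re-normalize
embedding_256 = F.normalize(embedding[:, :256], dim=-1)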
```

## File layout

```
AIT-75M.safetensors     # All components in one file (~288 MB)
```

### Tensor key prefixes

| Prefix | Component | Tensors |
|---|---|---|
| `text_encoder.*` | LEAF-IR (float32) | 103 |
| `image_encoder.*` | MobileNetV4-Medium | 462 |
| `audio_encoder.*` | EfficientAT mn20_as | 312 |
| `image_projection.*` | Projection head | 10 |
| `audio_projection.*` | Projection head | 10 |
| `text_projection.*` | Projection head | 10 |

## Training

- **Loss**: InfoNCE (contrastive) with Matryoshka Representation Learning
- **Data**: ~2.2M synthetically generated trimodal triplets (WordNet) + 200K MSCOCO image-text pairs + 262K WavCaps audio-text pairs + 1.5M Nomic text pairs
- **Hardware**: 2x NVIDIA L4 GPUs
- **Text retrieval fine-tune**: Phase 1 warm start from d20 checkpoint, text-head-only with frozen image/audio heads, Nomic supervised text pairs mixed at lambda_tt=0.25
- **Optimizer**: AdamW, lr=1e-3, weight decay=1e-4, cosine scheduler
- **Epochs**: 7 (text fine-tune from pre-trained trimodal base)
- **Projection heads only**: source encoders are frozen during training
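
To make the loss in the list above concrete, here is a minimal pure-Python sketch of InfoNCE combined with Matryoshka Representation Learning: the contrastive loss is evaluated on each truncated, re-normalized prefix and the results are averaged. The symmetric formulation, the equal weighting across dimensions, and the temperature of 0.07 are assumptions chosen for illustration, not the documented training configuration.

```python
import math

MATRYOSHKA_DIMS = (1280, 768, 512, 256, 128)

def normalize_prefix(vec, dim):
    """Keep the first `dim` coordinates and rescale to unit length."""
    head = vec[:dim]
    norm = math.sqrt(sum(x * x for x in head)) or 1.0
    return [x / norm for x in head]

def info_nce(queries, keys, temperature):
    """Mean cross-entropy of picking keys[i] for queries[i] among all keys."""
    total = 0.0
    for i, q in enumerate(queries):
        logits = [sum(a * b for a, b in zip(q, k)) / temperature for k in keys]
        peak = max(logits)  # subtract the max for a stable log-sum-exp
        log_denominator = peak + math.log(sum(math.exp(l - peak) for l in logits))
        total += log_denominator - logits[i]
    return total / len(queries)

def matryoshka_info_nce(a, b, dims=MATRYOSHKA_DIMS, temperature=0.07):
    """Average symmetric InfoNCE over every truncation length in `dims`."""
    losses = []
    for dim in dims:
        a_d = [normalize_prefix(v, dim) for v in a]
        b_d = [normalize_prefix(v, dim) for v in b]
        losses.append(0.5 * (info_nce(a_d, b_d, temperature)
                             + info_nce(b_d, a_d, temperature)))
    return sum(losses) / len(losses)

# Toy batch of two paired 4-dim embeddings, truncated to 4 and 2 dims.
batch_a = [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]]
aligned_loss = matryoshka_info_nce(batch_a, batch_a, dims=(4, 2))
shuffled_loss = matryoshka_info_nce(batch_a, batch_a[::-1], dims=(4, 2))
```

Correctly paired batches give a loss near zero, while mismatched pairs are penalized heavily at every truncation length, which is what pushes the leading coordinates to carry most of the retrieval signal.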

### Design decisions

- **3-head shared space**: All modalities project into a learned 1280-dim space (image-native dimension) instead of targeting a pre-existing text encoder space
- **LEAF-IR text encoder**: 23M-param retrieval-optimized text encoder replaces 300M Gemma, enabling fully edge-deployable text inference
- **Frozen source encoders**: MobileNetV4, EfficientAT, and LEAF-IR are kept frozen; only projection heads are trained
- **Text retrieval fine-tune**: Nomic supervised text pairs (1.5M) mixed into trimodal training to improve text-text retrieval while preserving cross-modal alignment
- **Edge-first**: All source encoders can run on devices like Raspberry Pi 5

## Limitations

- Audio retrieval lags behind specialist models such as CLAP on AudioCaps
- Image-text retrieval gives up accuracy relative to models with larger vision encoders in exchange for edge deployability
- Text retrieval recovers 58-92% of raw LEAF-IR quality (gap is domain-dependent)

## Links

- **Website**: [augmem.ai](https://augmem.ai)
- **GitHub**: [github.com/augmem](https://github.com/augmem)

## License

Apache 2.0