---
language:
- en
license: apache-2.0
tags:
- multimodal
- embedding
- matryoshka
- trimodal
- image-text-audio
- retrieval
- cross-modal
- edge
- rag
library_name: safetensors
pipeline_tag: feature-extraction
base_model:
- google/embeddinggemma-300m
datasets:
- custom
---

# TEG-421M: Trimodal Embeddings Gemma

**TEG** (Trimodal Embeddings Gemma) maps image and audio into the same embedding space as text, enabling cross-modal retrieval with a single vector index. All three modalities share a unified 768-dim space via [Google's embeddinggemma-300M](https://huggingface.co/google/embeddinggemma-300m), with full Matryoshka truncation support down to 128 dims.

> Also available in [GGUF format](https://huggingface.co/augmem/teg-421m-gguf) for quantized edge deployment.

## Architecture

TEG combines lightweight edge encoders with deep projection heads that distill into Gemma's embedding space:

```
Text  ──→ embeddinggemma-300M ──────────────────────→ 768-dim (L2-normalized)
Image ──→ MobileNetV4-Medium (1280-d) ──→ DeepProjectionHead ──→ 768-dim
Audio ──→ EfficientAT mn20_as (1920-d) ──→ DeepProjectionHead ──→ 768-dim
```

| Component | Architecture | Params | Size |
|---|---|---|---|
| Text encoder | embeddinggemma-300M (bfloat16) | 307.6M | 586.7 MB |
| Image encoder | MobileNetV4-Medium (timm) | 8.5M | 32.4 MB |
| Audio encoder | EfficientAT mn20_as | 18.0M | 68.5 MB |
| Image projection | DeepProjectionHead (1280 → 768) | 42.0M | 160.1 MB |
| Audio projection | DeepProjectionHead (1920 → 768) | 44.6M | 170.1 MB |
| **Total** | | **420.6M** | **1017.8 MB** |

### Projection head detail

Each `DeepProjectionHead` is a residual MLP:

```
Linear(encoder_dim, 4096) → GELU → LayerNorm → Dropout(0.3)
  → Linear(4096, 4096) → GELU → LayerNorm → Dropout(0.3) + residual
  → Linear(4096, 768)
  β†’ Linear(4096, 768)
```
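As a sanity check on the Params column, the arithmetic below counts weights for this layout, assuming every Linear and LayerNorm has a bias. With the single 4096 → 4096 block drawn above, the image head comes to about 25.2M parameters; with two such blocks it comes to about 42.0M (image) and 44.6M (audio), which matches the component table, and to 14 tensors per head, which matches the Tensor key prefixes table. This suggests the diagram abbreviates a repeated hidden block, but that is an inference from the counts, not a documented fact. The same arithmetic indicates the Size column is in MiB of float32 weights (41.98M × 4 bytes ≈ 160.1 MiB).

```python
# Parameter/size arithmetic for a DeepProjectionHead-style MLP.
# Assumption (inferred, not documented): biases on every Linear and LayerNorm;
# Dropout, GELU, and the residual add contribute no parameters.


def linear_params(d_in: int, d_out: int) -> int:
    return d_in * d_out + d_out  # weight matrix + bias vector


def layernorm_params(d: int) -> int:
    return 2 * d  # elementwise gain + bias


def projection_head_params(encoder_dim: int, hidden: int = 4096, out: int = 768,
                           residual_blocks: int = 1) -> int:
    """Input Linear + LayerNorm, N hidden (Linear + LayerNorm) blocks, then an output Linear."""
    total = linear_params(encoder_dim, hidden) + layernorm_params(hidden)
    total += residual_blocks * (linear_params(hidden, hidden) + layernorm_params(hidden))
    total += linear_params(hidden, out)
    return total


def projection_head_tensors(residual_blocks: int = 1) -> int:
    # weight + bias for: input Linear, input LayerNorm, each block's Linear and LayerNorm, output Linear
    return 2 * (2 + 2 * residual_blocks + 1)


def fp32_mib(num_params: int) -> float:
    return num_params * 4 / 2**20  # 4 bytes per float32 value, 2**20 bytes per MiB


for blocks in (1, 2):
    for name, dim in (("image", 1280), ("audio", 1920)):
        n = projection_head_params(dim, residual_blocks=blocks)
        print(f"{blocks} hidden block(s), {name}: {n / 1e6:.1f}M params, "
              f"{fp32_mib(n):.1f} MiB fp32, {projection_head_tensors(blocks)} tensors")
```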

### Matryoshka dimensions

Embeddings can be truncated to `[768, 512, 256, 128]` dimensions while largely preserving retrieval quality. The model was trained with Matryoshka Representation Learning (MRL) using per-size loss weights `[1.0, 1.0, 2.0, 4.0]` for 768, 512, 256, and 128 dims respectively.
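The MRL objective can be sketched as a weighted sum of contrastive losses, one per truncation size. This is a minimal NumPy illustration, not the training code: it assumes a symmetric (both-direction) InfoNCE with in-batch negatives, a hypothetical temperature of 0.05, and that each projected image or audio embedding is contrasted against the Gemma text embedding from the same triplet.

```python
import numpy as np

# Truncation sizes and per-size loss weights from the model card.
MATRYOSHKA_DIMS = (768, 512, 256, 128)
MATRYOSHKA_WEIGHTS = (1.0, 1.0, 2.0, 4.0)


def l2_normalize(x: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    return x / np.maximum(np.linalg.norm(x, axis=-1, keepdims=True), eps)


def log_softmax(logits: np.ndarray, axis: int) -> np.ndarray:
    # Numerically stable log-softmax along `axis`.
    shifted = logits - logits.max(axis=axis, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=axis, keepdims=True))


def info_nce(a: np.ndarray, b: np.ndarray, temperature: float = 0.05) -> float:
    """Symmetric InfoNCE: row i of `a` and row i of `b` are positives; other rows are in-batch negatives.

    Assumptions (not from the card): symmetric form and temperature=0.05.
    """
    logits = l2_normalize(a) @ l2_normalize(b).T / temperature  # (N, N) scaled cosine similarities
    idx = np.arange(len(a))
    loss_a_to_b = -log_softmax(logits, axis=1)[idx, idx].mean()  # each row of a must pick its b
    loss_b_to_a = -log_softmax(logits, axis=0)[idx, idx].mean()  # each row of b must pick its a
    return float(0.5 * (loss_a_to_b + loss_b_to_a))


def matryoshka_info_nce(projected: np.ndarray, target: np.ndarray, dims=MATRYOSHKA_DIMS,
                        weights=MATRYOSHKA_WEIGHTS, temperature: float = 0.05) -> float:
    """Weighted sum of InfoNCE losses over nested prefixes of the embedding.

    projected: (N, 768) projection-head outputs; target: (N, 768) Gemma text embeddings.
    Each prefix is re-normalized inside info_nce before similarities are computed.
    """
    return sum(w * info_nce(projected[:, :d], target[:, :d], temperature)
               for d, w in zip(dims, weights))
```

Because the 128-dim term carries 4x the weight of the full-size term, misalignment that survives truncation is penalized most, consistent with the card's goal of keeping quality at aggressive truncation levels.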

## Benchmarks

All benchmarks were run on a single NVIDIA L4 GPU, using 5K samples where applicable.

### Cross-modal retrieval: SALT (5K trimodal samples)

| Direction | TEG-421M (421M) | LCO-3B (4.7B) | Nemotron-3B (4.7B) | ImageBind (1.2B) | EBind (1.78B*) |
|---|---|---|---|---|---|
| Text → Image R@1 | 0.672 | 0.660 | 0.529 | 0.712 | **0.779** |
| Image → Text R@1 | 0.620 | 0.564 | 0.299 | 0.736 | **0.783** |
| Text → Audio R@1 | **0.113** | 0.042 | 0.018 | 0.038 | 0.047 |
| Audio → Text R@1 | **0.115** | 0.032 | 0.010 | 0.039 | 0.035 |
| Audio → Image R@1 | **0.081** | 0.027 | 0.016 | 0.023 | 0.027 |
| Image → Audio R@1 | **0.083** | 0.034 | 0.018 | 0.025 | 0.032 |

TEG leads every audio cross-modal direction, by roughly 2.4x to 11.5x over models that are ~3-11x larger. Image↔Audio R@1 improved ~40% over v1 via joint cross-modal training. On vision-text, TEG trails EBind and ImageBind but uses encoders small enough for edge deployment.
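Recall@1 is the fraction of queries whose paired item ranks first among all candidates. The NumPy sketch below shows one way to compute Recall@K from two aligned embedding matrices; it is a simplified illustration, not the benchmark harness behind these tables. It assumes row *i* of the queries pairs with row *i* of the candidates, so it does not handle benchmarks such as MSCOCO where each image has several captions.

```python
from typing import Optional

import numpy as np


def l2_normalize(x: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    return x / np.maximum(np.linalg.norm(x, axis=-1, keepdims=True), eps)


def recall_at_k(queries: np.ndarray, candidates: np.ndarray, k: int = 1,
                dim: Optional[int] = None) -> float:
    """Fraction of queries whose paired candidate (same row index) ranks in the top k by cosine similarity.

    Simplified 1:1 pairing; `dim` optionally truncates both sides to a Matryoshka prefix
    (re-normalized before scoring).
    """
    if dim is not None:
        queries, candidates = queries[:, :dim], candidates[:, :dim]
    sims = l2_normalize(queries) @ l2_normalize(candidates).T  # (N, N) cosine similarities
    positive = np.diag(sims)
    # 0-based rank of the paired candidate = number of candidates scoring strictly higher,
    # so ties are resolved in the positive's favor.
    rank = (sims > positive[:, None]).sum(axis=1)
    return float((rank < k).mean())
```

For example, `recall_at_k(text_emb, audio_emb, k=1)` would give a Text → Audio R@1, and passing `dim=128` scores the same pairs at the smallest Matryoshka size.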

### Audio retrieval β€” AudioCaps & Clotho

| Benchmark | Direction | TEG-421M | LCO-3B | Nemotron-3B | CLAP-Small | CLAP-Large | ImageBind | EBind |
|---|---|---|---|---|---|---|---|---|
| AudioCaps | A→T R@1 | 0.159 | 0.250 | 0.050 | **0.425** | 0.420 | 0.116 | 0.225 |
| AudioCaps | T→A R@1 | 0.149 | 0.215 | 0.075 | **0.315** | 0.280 | 0.080 | 0.219 |
| Clotho | A→T R@1 | 0.168 | 0.178 | 0.038 | 0.166 | **0.195** | 0.061 | 0.088 |
| Clotho | T→A R@1 | 0.123 | **0.187** | 0.070 | 0.159 | 0.167 | 0.074 | 0.118 |

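CLAP models, which are audio specialists with no image support, lead three of the four rows (LCO-3B leads Clotho T→A). Among trimodal models, TEG is second only to LCO-3B on Clotho (0.168 vs 0.178 A→T R@1) while being ~11x smaller, but it trails both LCO-3B and EBind on AudioCaps.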

### Image-text retrieval: MSCOCO & Flickr30k

| Benchmark | Direction | TEG-421M (421M) | EBind (1.78B*) | ImageBind (1.2B) | LCO-3B (4.7B) | Nemotron-3B (4.7B) |
|---|---|---|---|---|---|---|
| MSCOCO 5K | I→T R@1 | 0.248 | **0.743** | 0.658 | 0.533 | 0.225 |
| MSCOCO 5K | T→I R@1 | 0.180 | **0.559** | 0.490 | 0.469 | 0.334 |
| MSCOCO 5K | I→T R@10 | 0.622 | **0.948** | 0.918 | 0.784 | 0.630 |
| Flickr30k | I→T R@1 | 0.498 | — | — | **0.840** | 0.419 |
| Flickr30k | T→I R@1 | 0.358 | — | — | **0.765** | 0.563 |

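TEG's image-text retrieval trades accuracy for edge deployability: MobileNetV4-Medium (8.5M parameters) is dozens of times smaller than the ViT-L/ViT-H image encoders (roughly 300M-630M parameters) used by competitors. TEG still edges out Nemotron-3B on I→T R@1 on both MSCOCO (0.248 vs 0.225) and Flickr30k (0.498 vs 0.419) despite being ~11x smaller.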

### Zero-shot classification: ESC-50

| Model | Params | Accuracy |
|---|---|---|
| CLAP-Large | 67.8M | **0.905** |
| LCO-3B | 4.7B | 0.853 |
| TEG-421M | 421M | 0.820 |
| EBind | 1.78B* | 0.770 |
| CLAP-Small | 27.5M | 0.751 |
| Nemotron-3B | 4.7B | 0.727 |
| ImageBind | 1.2B | 0.664 |
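Zero-shot classification in a shared embedding space typically embeds each class name as text and assigns every clip to the most similar label; the card does not describe its exact protocol. The sketch below assumes you already have audio embeddings and label text embeddings as NumPy arrays. `zero_shot_classify` is a hypothetical helper, and the wording of the label prompts (for example a template like `"the sound of a {label}"`) is an assumption left to the caller.

```python
from typing import List, Sequence

import numpy as np


def zero_shot_classify(audio_emb: np.ndarray, label_text_emb: np.ndarray,
                       labels: Sequence[str], dim: int = 768) -> List[str]:
    """Hypothetical helper: assign each clip the label whose text embedding is most cosine-similar.

    audio_emb: (num_clips, 768); label_text_emb: (num_labels, 768), one row per entry in `labels`.
    `dim` optionally truncates both sides to a Matryoshka prefix.
    """
    a = audio_emb[:, :dim]
    t = label_text_emb[:, :dim]
    a = a / np.linalg.norm(a, axis=1, keepdims=True)
    t = t / np.linalg.norm(t, axis=1, keepdims=True)
    scores = a @ t.T  # (num_clips, num_labels) cosine similarities
    return [labels[i] for i in scores.argmax(axis=1)]
```

Accuracy is then the fraction of predicted labels that match the ground-truth class of each clip.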


## Usage

### Loading components

```python
from safetensors.torch import load_file

# Load entire model
tensors = load_file("teg-421m.safetensors")

# Extract components by prefix
gemma_sd = {k.removeprefix("gemma."): v for k, v in tensors.items() if k.startswith("gemma.")}
image_enc_sd = {k.removeprefix("image_encoder."): v for k, v in tensors.items() if k.startswith("image_encoder.")}
audio_enc_sd = {k.removeprefix("audio_encoder."): v for k, v in tensors.items() if k.startswith("audio_encoder.")}
image_proj_sd = {k.removeprefix("image_projection."): v for k, v in tensors.items() if k.startswith("image_projection.")}
audio_proj_sd = {k.removeprefix("audio_projection."): v for k, v in tensors.items() if k.startswith("audio_projection.")}
```

### Reading metadata

```python
from safetensors import safe_open

with safe_open("teg-421m.safetensors", framework="pt") as f:
    metadata = f.metadata()
    print(metadata)
    # Keys: format, version, text_model, embed_dim, image_encoder_name,
    #        image_encoder_dim, audio_encoder_name, audio_encoder_dim,
    #        audio_sample_rate, matryoshka_dims, total_parameters, ...
```
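safetensors metadata is a flat string-to-string mapping, so numeric fields such as `matryoshka_dims` come back as strings. The exact encoding used in this file is not documented here; the sketch below assumes the value is either a JSON list or a comma-separated string and accepts both. `parse_matryoshka_dims` is a hypothetical helper.

```python
import json
from typing import Dict, List, Optional


def parse_matryoshka_dims(metadata: Optional[Dict[str, str]]) -> List[int]:
    """Hypothetical helper: decode the `matryoshka_dims` metadata string into a list of ints.

    Assumption: the value is a JSON list ("[768, 512, 256, 128]") or a
    comma-separated string ("768,512,256,128"). Returns [] if the key is absent.
    """
    raw = (metadata or {}).get("matryoshka_dims")  # f.metadata() may return None
    if raw is None:
        return []
    try:
        value = json.loads(raw)
    except json.JSONDecodeError:
        value = raw.split(",")  # fall back to comma-separated
    if isinstance(value, int):
        return [value]
    return [int(v) for v in value]  # int() tolerates surrounding whitespace
```

With the snippet above, `parse_matryoshka_dims(metadata)` would give the list of truncation sizes to iterate over.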

### Matryoshka truncation

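```python
import torch
import torch.nn.functional as F

# Full 768-dim, L2-normalized embeddings from any TEG tower (text, image, or audio).
# A random tensor stands in here for the model output.
embedding = F.normalize(torch.randn(4, 768), dim=-1)  # (N, 768)

# Truncate to one of the Matryoshka sizes (768, 512, 256, 128) and re-normalize,
# so cosine similarity is still a plain dot product.
embedding_256 = F.normalize(embedding[:, :256], dim=-1)  # (N, 256)
```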

## File layout

```
teg-421m.safetensors     # All components in one file (~1 GB)
```

### Tensor key prefixes

| Prefix | Component | Tensors |
|---|---|---|
| `gemma.*` | embeddinggemma-300M (bfloat16) | 316 |
| `image_encoder.*` | MobileNetV4-Medium | 462 |
| `audio_encoder.*` | EfficientAT mn20_as | 312 |
| `image_projection.*` | Deep projection head | 14 |
| `audio_projection.*` | Deep projection head | 14 |
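A quick consistency check after loading is to group the keys of the `tensors` dict from the loading snippet by top-level prefix and compare with this table. The sketch below works on plain shape tuples so it does not depend on torch, and it assumes every key has the form `<prefix>.<rest>`. If the encoders were exported from a PyTorch `state_dict`, BatchNorm buffers (running statistics) would likely be included, so element totals for the encoders may differ slightly from the Params column.

```python
from math import prod
from typing import Dict, Tuple


def summarize_components(shapes: Dict[str, Tuple[int, ...]]) -> Dict[str, Tuple[int, int]]:
    """Group tensors by top-level key prefix; return {prefix: (tensor_count, element_count)}.

    `shapes` maps tensor name -> shape tuple, e.g. {k: tuple(v.shape) for k, v in tensors.items()}.
    """
    summary: Dict[str, Tuple[int, int]] = {}
    for name, shape in shapes.items():
        prefix = name.split(".", 1)[0]
        count, elements = summary.get(prefix, (0, 0))
        summary[prefix] = (count + 1, elements + prod(shape))  # prod(()) == 1 for scalars
    return summary
```

For example, `summarize_components({k: tuple(v.shape) for k, v in tensors.items()})["image_projection"]` should report 14 tensors if the file matches the table.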

## Training

- **Loss**: InfoNCE (contrastive) with Matryoshka Representation Learning
- **Data**: ~4.8M synthetically generated trimodal triplets (text, image, audio)
- **Hardware**: 2x NVIDIA L4 GPUs
- **Optimizer**: AdamW, lr=1e-3 (projections), weight decay=1e-4
- **Epochs**: ~22 (early stopping on validation recall)
- **Projection heads only**: encoders and Gemma are frozen during training

### Design decisions

- **Frozen source encoders**: MobileNetV4 and EfficientAT are kept frozen; only projection heads are trained via distillation into Gemma's space
- **Deep projection heads**: Residual MLPs with dropout outperformed shallow 2-layer heads, especially for audio
- **Matryoshka weighting**: Higher weight on smaller dimensions (4x at 128-dim) is meant to preserve quality at aggressive truncation levels
- **Edge-first**: Source encoders were chosen for edge deployment; MobileNetV4-Medium and EfficientAT mn20 can run on devices like Raspberry Pi 5

*\*EBind's [HuggingFace checkpoint](https://huggingface.co/encord-team/ebind-full) is 8.93M parameters (bridge heads only), but inference requires frozen backbones (SigLIP ViT-L, CLAP HTSAT, text encoder) totaling 1.78B loaded parameters as measured by our benchmark harness.*

## Limitations

- Audio retrieval lags behind specialist models like CLAP on audio-only benchmarks
- Image-text retrieval trades some accuracy vs larger vision encoders (SigLIP, CLIP ViT-L) for edge deployability
- Trained primarily on synthetic data; real-world distribution shifts may affect performance
- Text modality requires the full embeddinggemma-300M model (307M params, bfloat16)

## Links

- **Website**: [augmem.ai](https://augmem.ai)
- **GitHub**: [github.com/augmem](https://github.com/augmem)
- **GGUF variant**: [augmem/teg-421m-gguf](https://huggingface.co/augmem/teg-421m-gguf)

## License

Apache 2.0