---
language:
- en
license: apache-2.0
tags:
- multimodal
- embedding
- matryoshka
- trimodal
- image-text-audio
- retrieval
- cross-modal
- edge
- rag
library_name: safetensors
pipeline_tag: feature-extraction
base_model:
- google/embeddinggemma-300m
datasets:
- custom
---
# TEG-421M: Trimodal Embeddings Gemma
**TEG** (Trimodal Embeddings Gemma) maps image and audio into the same embedding space as text, enabling cross-modal retrieval with a single vector index. All three modalities share a unified 768-dim space via [Google's embeddinggemma-300M](https://huggingface.co/google/embeddinggemma-300m), with full Matryoshka truncation support down to 128 dims.
> Also available in [GGUF format](https://huggingface.co/augmem/teg-421m-gguf) for quantized edge deployment.
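Because all three modalities land in one L2-normalized space, image and audio vectors can share a single index and be queried with a text embedding. Below is a minimal sketch of that idea. The random tensors are hypothetical stand-ins for real TEG outputs, and `search` is an illustrative helper, not part of any released API:

```python
import torch
import torch.nn.functional as F

# Hypothetical stand-ins: in practice these rows would come from TEG's image
# and audio branches, all L2-normalized 768-dim vectors.
torch.manual_seed(0)
image_embs = F.normalize(torch.randn(3, 768), dim=-1)
audio_embs = F.normalize(torch.randn(2, 768), dim=-1)

# One index holds every modality; a parallel list records where each row came from.
index = torch.cat([image_embs, audio_embs], dim=0)  # (5, 768)
labels = ["image"] * 3 + ["audio"] * 2


def search(query: torch.Tensor, index: torch.Tensor, k: int = 3):
    """Return (scores, row ids) of the k rows most similar to `query`."""
    query = F.normalize(query, dim=-1)
    scores = index @ query  # dot product equals cosine similarity for unit vectors
    top = torch.topk(scores, k)
    return top.values, top.indices


# A "text" query built close to the second audio clip (row 4) should retrieve it first.
text_query = audio_embs[1] + 0.01 * torch.randn(768)
scores, ids = search(text_query, index)
print([(labels[i], round(s, 3)) for s, i in zip(scores.tolist(), ids.tolist())])
```

A production deployment would typically swap the dense matrix product for an approximate nearest-neighbor index, but the scoring rule stays the same.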
## Architecture
TEG combines lightweight edge encoders with deep projection heads that distill into Gemma's embedding space:
```
Text  ──→ embeddinggemma-300M ─────────────────────────────────→ 768-dim (L2-normalized)
Image ──→ MobileNetV4-Medium (1280-d) ──→ DeepProjectionHead ──→ 768-dim
Audio ──→ EfficientAT mn20_as (1920-d) ──→ DeepProjectionHead ──→ 768-dim
```
| Component | Architecture | Params | Size |
|---|---|---|---|
| Text encoder | embeddinggemma-300M (bfloat16) | 307.6M | 586.7 MB |
| Image encoder | MobileNetV4-Medium (timm) | 8.5M | 32.4 MB |
| Audio encoder | EfficientAT mn20_as | 18.0M | 68.5 MB |
| Image projection | DeepProjectionHead (1280 → 768) | 42.0M | 160.1 MB |
| Audio projection | DeepProjectionHead (1920 → 768) | 44.6M | 170.1 MB |
| **Total** | | **420.6M** | **1017.8 MB** |
### Projection head detail
Each `DeepProjectionHead` is a residual MLP:
```
Linear(encoder_dim, 4096) → GELU → LayerNorm → Dropout(0.3)
→ Linear(4096, 4096) → GELU → LayerNorm → Dropout(0.3) + residual
→ Linear(4096, 768)
```
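The diagram above can be sketched as a PyTorch module. This is a minimal sketch under stated assumptions: the attribute names, the `num_blocks` parameter, and the final L2 normalization are illustrative choices, and the parameter names will not necessarily match the keys stored under `image_projection.*` / `audio_projection.*`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class DeepProjectionHead(nn.Module):
    """Residual MLP mapping encoder features into a 768-dim embedding space.

    Sketch only: `num_blocks` and all attribute names are assumptions.
    """

    def __init__(self, encoder_dim: int, hidden_dim: int = 4096, out_dim: int = 768,
                 num_blocks: int = 1, dropout: float = 0.3):
        super().__init__()
        # Input stage: Linear -> GELU -> LayerNorm -> Dropout
        self.input = nn.Sequential(
            nn.Linear(encoder_dim, hidden_dim), nn.GELU(),
            nn.LayerNorm(hidden_dim), nn.Dropout(dropout),
        )
        # Hidden stage(s): same pattern, each wrapped in a residual connection
        self.blocks = nn.ModuleList(
            nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim), nn.GELU(),
                nn.LayerNorm(hidden_dim), nn.Dropout(dropout),
            )
            for _ in range(num_blocks)
        )
        self.output = nn.Linear(hidden_dim, out_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.input(x)
        for block in self.blocks:
            h = h + block(h)  # residual add after dropout
        # Assumption: outputs are L2-normalized to match Gemma's text embeddings
        return F.normalize(self.output(h), dim=-1)
```

Under these assumptions, `num_blocks=1` (one residual block, as drawn) gives about 25.2M parameters for the image head. With `num_blocks=2` the image head has 41,980,672 parameters (≈42.0M) in 14 tensors and the audio head has 44,602,112 (≈44.6M), which matches the component and tensor-prefix tables. That suggests, but does not confirm, that the released heads contain two hidden 4096 → 4096 blocks.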
### Matryoshka dimensions
Embeddings can be truncated to `[768, 512, 256, 128]` dimensions while preserving retrieval quality. The model was trained with Matryoshka Representation Learning (MRL), using loss weights `[1.0, 1.0, 2.0, 4.0]` for those dimensions respectively.
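A minimal sketch of a weighted Matryoshka InfoNCE objective follows. It assumes a symmetric InfoNCE loss at each truncation level, a temperature of 0.05, and normalization by the sum of the weights; none of these three details is stated in this card. Each prefix is re-normalized before similarities are computed, and its loss is scaled by the matching MRL weight.

```python
import torch
import torch.nn.functional as F

MRL_DIMS = [768, 512, 256, 128]
MRL_WEIGHTS = [1.0, 1.0, 2.0, 4.0]


def info_nce(a: torch.Tensor, b: torch.Tensor, temperature: float = 0.05) -> torch.Tensor:
    """Symmetric InfoNCE: row i of `a` is the positive for row i of `b`."""
    logits = a @ b.T / temperature
    targets = torch.arange(a.size(0), device=a.device)
    return 0.5 * (F.cross_entropy(logits, targets) + F.cross_entropy(logits.T, targets))


def matryoshka_info_nce(a: torch.Tensor, b: torch.Tensor,
                        dims=MRL_DIMS, weights=MRL_WEIGHTS,
                        temperature: float = 0.05) -> torch.Tensor:
    """Weighted average of InfoNCE losses over truncated, re-normalized prefixes."""
    total = sum(
        w * info_nce(F.normalize(a[:, :d], dim=-1), F.normalize(b[:, :d], dim=-1), temperature)
        for d, w in zip(dims, weights)
    )
    return total / sum(weights)
```

The 4x weight on the 128-dim prefix makes errors at the smallest truncation cost the most, which is the intent described under Design decisions.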
## Benchmarks
All benchmarks were run on a single NVIDIA L4 GPU, using 5K samples where applicable.
### Cross-modal retrieval: SALT (5K trimodal samples)
| Direction | TEG-421M (421M) | LCO-3B (4.7B) | Nemotron-3B (4.7B) | ImageBind (1.2B) | EBind (1.78B*) |
|---|---|---|---|---|---|
| Text → Image R@1 | 0.672 | 0.660 | 0.529 | 0.712 | **0.779** |
| Image → Text R@1 | 0.620 | 0.564 | 0.299 | 0.736 | **0.783** |
| Text → Audio R@1 | **0.113** | 0.042 | 0.018 | 0.038 | 0.047 |
| Audio → Text R@1 | **0.115** | 0.032 | 0.010 | 0.039 | 0.035 |
| Audio → Image R@1 | **0.081** | 0.027 | 0.016 | 0.023 | 0.027 |
| Image → Audio R@1 | **0.083** | 0.034 | 0.018 | 0.025 | 0.032 |
TEG leads every audio cross-modal direction, by roughly 2-12x over models that are 3-11x larger. Image-audio retrieval improved ~40% over v1 via joint cross-modal training. Vision-text retrieval trails EBind and ImageBind, but TEG uses encoders small enough for edge deployment.
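The tables report Recall@K (R@1, R@10): the fraction of queries whose matching item ranks among the top K results. A minimal sketch of the metric, assuming L2-normalized embeddings and one-to-one pairing (query *i* matches gallery item *i*). Benchmarks with several captions per image, such as MSCOCO, need a many-to-one variant that this sketch does not cover, and the card's own benchmark harness may differ.

```python
import torch


def recall_at_k(query_embs: torch.Tensor, gallery_embs: torch.Tensor, k: int = 1) -> float:
    """Fraction of queries whose paired gallery row (same index) is in the top k.

    Assumes both inputs are L2-normalized and paired row-by-row.
    """
    sims = query_embs @ gallery_embs.T                         # (N, N) cosine similarities
    topk = sims.topk(k, dim=-1).indices                        # (N, k) best gallery ids per query
    targets = torch.arange(query_embs.size(0)).unsqueeze(-1)   # (N, 1) correct ids
    return (topk == targets).any(dim=-1).float().mean().item()
```

For example, `recall_at_k(text_embs, audio_embs, k=1)` would correspond to a Text → Audio R@1 entry, given embeddings from the respective TEG branches.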
### Audio retrieval: AudioCaps & Clotho
| Benchmark | Direction | TEG-421M | LCO-3B | Nemotron-3B | CLAP-Small | CLAP-Large | ImageBind | EBind |
|---|---|---|---|---|---|---|---|---|
| AudioCaps | A→T R@1 | 0.159 | 0.250 | 0.050 | **0.425** | 0.420 | 0.116 | 0.225 |
| AudioCaps | T→A R@1 | 0.149 | 0.215 | 0.075 | **0.315** | 0.280 | 0.080 | 0.219 |
| Clotho | A→T R@1 | 0.168 | 0.178 | 0.038 | 0.166 | **0.195** | 0.061 | 0.088 |
| Clotho | T→A R@1 | 0.123 | **0.187** | 0.070 | 0.159 | 0.167 | 0.074 | 0.118 |
CLAP models lead on audio-only benchmarks (audio specialists with no image support). Among trimodal models, TEG is competitive with LCO while being 11x smaller.
### Image-text retrieval: MSCOCO & Flickr30k
| Benchmark | Direction | TEG-421M (421M) | EBind (1.78B*) | ImageBind (1.2B) | LCO-3B (4.7B) | Nemotron-3B (4.7B) |
|---|---|---|---|---|---|---|
| MSCOCO 5K | I→T R@1 | 0.248 | **0.743** | 0.658 | 0.533 | 0.225 |
| MSCOCO 5K | T→I R@1 | 0.180 | **0.559** | 0.490 | 0.469 | 0.334 |
| MSCOCO 5K | I→T R@10 | 0.622 | **0.948** | 0.918 | 0.784 | 0.630 |
| Flickr30k | I→T R@1 | 0.498 | n/a | n/a | **0.840** | 0.419 |
| Flickr30k | T→I R@1 | 0.358 | n/a | n/a | **0.765** | 0.563 |
TEG's image-text retrieval trades accuracy for edge deployability: MobileNetV4-Medium is ~100x smaller than the ViT-H/ViT-L encoders used by competitors. On MSCOCO, TEG outperforms Nemotron-3B on I→T R@1 (0.248 vs. 0.225) despite being 11x smaller.
### Zero-shot classification: ESC-50
| Model | Params | Accuracy |
|---|---|---|
| CLAP-Large | 67.8M | **0.905** |
| LCO-3B | 4.7B | 0.853 |
| TEG-421M | 421M | 0.820 |
| EBind | 1.78B* | 0.770 |
| CLAP-Small | 27.5M | 0.751 |
| Nemotron-3B | 4.7B | 0.727 |
| ImageBind | 1.2B | 0.664 |
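Zero-shot accuracy on a benchmark like ESC-50 is commonly computed by embedding one text prompt per class and assigning each clip to the most similar prompt. A minimal sketch of that procedure, assuming precomputed audio and class-prompt embeddings; the prompt template used by this card's harness is not documented here, and the helper names are hypothetical:

```python
import torch
import torch.nn.functional as F


def zero_shot_classify(clip_embs: torch.Tensor, class_text_embs: torch.Tensor) -> torch.Tensor:
    """Predict one class id per clip: the index of the most similar class-prompt embedding."""
    clip_embs = F.normalize(clip_embs, dim=-1)
    class_text_embs = F.normalize(class_text_embs, dim=-1)
    return (clip_embs @ class_text_embs.T).argmax(dim=-1)


def zero_shot_accuracy(clip_embs: torch.Tensor, class_text_embs: torch.Tensor,
                       labels: torch.Tensor) -> float:
    """Fraction of clips whose predicted class equals the ground-truth label."""
    preds = zero_shot_classify(clip_embs, class_text_embs)
    return (preds == labels).float().mean().item()
```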
## Usage
### Loading components
```python
from safetensors.torch import load_file
# Load entire model
tensors = load_file("teg-421m.safetensors")
# Extract components by prefix
gemma_sd = {k.removeprefix("gemma."): v for k, v in tensors.items() if k.startswith("gemma.")}
image_enc_sd = {k.removeprefix("image_encoder."): v for k, v in tensors.items() if k.startswith("image_encoder.")}
audio_enc_sd = {k.removeprefix("audio_encoder."): v for k, v in tensors.items() if k.startswith("audio_encoder.")}
image_proj_sd = {k.removeprefix("image_projection."): v for k, v in tensors.items() if k.startswith("image_projection.")}
audio_proj_sd = {k.removeprefix("audio_projection."): v for k, v in tensors.items() if k.startswith("audio_projection.")}
```
### Reading metadata
```python
from safetensors import safe_open

# safetensors stores metadata in the file header as a str -> str mapping
with safe_open("teg-421m.safetensors", framework="pt") as f:
    metadata = f.metadata()
print(metadata)
# Keys: format, version, text_model, embed_dim, image_encoder_name,
# image_encoder_dim, audio_encoder_name, audio_encoder_dim,
# audio_sample_rate, matryoshka_dims, total_parameters, ...
```
### Matryoshka truncation
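```python
import torch
import torch.nn.functional as F

# `embedding` stands in for the (N, 768) output of any TEG branch (text, image,
# or audio); random values are used here only so the snippet runs on its own.
embedding = F.normalize(torch.randn(4, 768), dim=-1)  # (N, 768), L2-normalized

# Truncate to the first 256 dims and re-normalize to unit length
embedding_256 = F.normalize(embedding[:, :256], dim=-1)  # (N, 256)
```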
## File layout
```
teg-421m.safetensors # All components in one file (~1 GB)
```
### Tensor key prefixes
| Prefix | Component | Tensors |
|---|---|---|
| `gemma.*` | embeddinggemma-300M (bfloat16) | 316 |
| `image_encoder.*` | MobileNetV4-Medium | 462 |
| `audio_encoder.*` | EfficientAT mn20_as | 312 |
| `image_projection.*` | Deep projection head | 14 |
| `audio_projection.*` | Deep projection head | 14 |
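To sanity-check a downloaded file against the table above, tensor names can be grouped by their top-level prefix. A minimal sketch; `check_layout` and `EXPECTED` are illustrative helpers, and only `safe_open` and its `keys()` method come from the safetensors library:

```python
from collections import Counter
from typing import Iterable

# Expected tensor counts per prefix, copied from the table above
EXPECTED = {"gemma": 316, "image_encoder": 462, "audio_encoder": 312,
            "image_projection": 14, "audio_projection": 14}


def count_by_prefix(keys: Iterable[str]) -> Counter:
    """Count tensor names by top-level prefix (the text before the first '.')."""
    return Counter(k.split(".", 1)[0] for k in keys)


def check_layout(path: str = "teg-421m.safetensors") -> dict:
    """Return {prefix: (found, expected)} for each expected prefix whose count differs."""
    from safetensors import safe_open  # imported lazily so count_by_prefix works standalone
    with safe_open(path, framework="pt") as f:
        counts = count_by_prefix(f.keys())
    return {p: (counts.get(p, 0), n) for p, n in EXPECTED.items() if counts.get(p, 0) != n}
```

An empty dict from `check_layout()` means every expected prefix has the listed number of tensors; prefixes not in the table are ignored.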
## Training
- **Loss**: InfoNCE (contrastive) with Matryoshka Representation Learning
- **Data**: ~4.8M synthetically generated trimodal triplets (text, image, audio)
- **Hardware**: 2x NVIDIA L4 GPUs
- **Optimizer**: AdamW, lr=1e-3 (projections), weight decay=1e-4
- **Epochs**: ~22 (early stopping on validation recall)
- **Projection heads only**: encoders and Gemma are frozen during training
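The setup above (frozen encoders, AdamW on projection heads only, early stopping on validation recall) can be sketched as follows. The helper names and the early-stopping `patience` value are assumptions; the card states only the optimizer, learning rate, weight decay, and that early stopping tracked validation recall.

```python
import torch
import torch.nn as nn


def build_optimizer(encoders: list[nn.Module], heads: list[nn.Module]) -> torch.optim.Optimizer:
    """Freeze encoders/Gemma and return AdamW over the projection-head parameters only."""
    for enc in encoders:
        enc.requires_grad_(False)  # no gradients flow into frozen weights
        enc.eval()                 # disables encoder dropout and batch-norm statistic updates
    params = [p for head in heads for p in head.parameters()]
    return torch.optim.AdamW(params, lr=1e-3, weight_decay=1e-4)


class EarlyStopping:
    """Signal a stop once validation recall fails to improve for `patience` epochs."""

    def __init__(self, patience: int = 3):  # patience is an assumed value
        self.patience = patience
        self.best = float("-inf")
        self.bad_epochs = 0

    def step(self, val_recall: float) -> bool:
        if val_recall > self.best:
            self.best, self.bad_epochs = val_recall, 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience
```

Because the frozen modules are excluded from the optimizer and have `requires_grad=False`, only the two heads (about 86.6M parameters in total) receive updates.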
### Design decisions
- **Frozen source encoders**: MobileNetV4 and EfficientAT are kept frozen; only projection heads are trained via distillation into Gemma's space
- **Deep projection heads**: Residual MLPs with dropout outperformed shallow 2-layer heads, especially for audio
- **Matryoshka weighting**: Higher weight on smaller dimensions (4x at 128-dim) ensures quality at aggressive truncation levels
- **Edge-first**: Source encoders chosen for edge deployment; MobileNetV4-Medium and EfficientAT mn20 can run on devices like Raspberry Pi 5
*\*EBind's [HuggingFace checkpoint](https://huggingface.co/encord-team/ebind-full) is 8.93M parameters (bridge heads only), but inference requires frozen backbones (SigLIP ViT-L, CLAP HTSAT, text encoder) totaling 1.78B loaded parameters as measured by our benchmark harness.*
## Limitations
- Audio retrieval lags behind specialist models like CLAP on audio-only benchmarks
- Image-text retrieval trades some accuracy relative to larger vision encoders (SigLIP, CLIP ViT-L) for edge deployability
- Trained primarily on synthetic data β real-world distribution shifts may affect performance
- Text modality requires the full embeddinggemma-300M model (307M params, bfloat16)
## Links
- **Website**: [augmem.ai](https://augmem.ai)
- **GitHub**: [github.com/augmem](https://github.com/augmem)
- **GGUF variant**: [augmem/teg-421m-gguf](https://huggingface.co/augmem/teg-421m-gguf)
## License
Apache 2.0