---
language:
- en
license: apache-2.0
tags:
- multimodal
- embedding
- matryoshka
- trimodal
- image-text-audio
- retrieval
- cross-modal
- edge
- rag
library_name: safetensors
pipeline_tag: feature-extraction
base_model:
- google/embeddinggemma-300m
datasets:
- custom
---
# TEG-421M — Trimodal Embeddings Gemma
**TEG** (Trimodal Embeddings Gemma) maps image and audio into the same embedding space as text, enabling cross-modal retrieval with a single vector index. All three modalities share a unified 768-dim space via [Google's embeddinggemma-300M](https://huggingface.co/google/embeddinggemma-300m), with full Matryoshka truncation support down to 128 dims.
> Also available in [GGUF format](https://huggingface.co/augmem/teg-421m-gguf) for quantized edge deployment.
## Architecture
TEG combines lightweight edge encoders with deep projection heads that distill into Gemma's embedding space:
```
Text ──→ embeddinggemma-300M ──────────────────────→ 768-dim (L2-normalized)
Image ──→ MobileNetV4-Medium (1280-d) ──→ DeepProjectionHead ──→ 768-dim
Audio ──→ EfficientAT mn20_as (1920-d) ──→ DeepProjectionHead ──→ 768-dim
```
| Component | Architecture | Params | Size |
|---|---|---|---|
| Text encoder | embeddinggemma-300M (bfloat16) | 307.6M | 586.7 MB |
| Image encoder | MobileNetV4-Medium (timm) | 8.5M | 32.4 MB |
| Audio encoder | EfficientAT mn20_as | 18.0M | 68.5 MB |
| Image projection | DeepProjectionHead (1280 → 768) | 42.0M | 160.1 MB |
| Audio projection | DeepProjectionHead (1920 → 768) | 44.6M | 170.1 MB |
| **Total** | | **420.6M** | **1017.8 MB** |
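The Size column can be checked with a little arithmetic. Its figures line up with Gemma stored in bfloat16 (2 bytes per parameter), the image encoder stored in float32 (4 bytes), and sizes counted in binary megabytes (2^20 bytes). The card does not state the non-Gemma dtypes; they are inferred from this arithmetic. Small mismatches in other rows come from rounding in the Params column.

```python
# Bytes per parameter for each storage dtype
BYTES_PER_PARAM = {"bfloat16": 2, "float32": 4}
MIB = 2 ** 20  # the Size column appears to use binary megabytes


def size_mib(params: float, dtype: str) -> float:
    """On-disk size in MiB of `params` parameters stored as `dtype`."""
    return params * BYTES_PER_PARAM[dtype] / MIB


# (parameter count from the table, dtype inferred from the Size column)
components = {
    "text_encoder": (307.6e6, "bfloat16"),
    "image_encoder": (8.5e6, "float32"),
}

for name, (params, dtype) in components.items():
    print(f"{name}: {size_mib(params, dtype):.1f} MiB")
```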
### Projection head detail
Each `DeepProjectionHead` is a residual MLP:
```
Linear(encoder_dim, 4096) → GELU → LayerNorm → Dropout(0.3)
→ [Linear(4096, 4096) → GELU → LayerNorm → Dropout(0.3) + residual] × 2
→ Linear(4096, 768)
```
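Counting parameters layer by layer is a quick way to check the head against the component table. A head with a single hidden 4096 → 4096 block comes to about 25.2M (image) and 27.8M (audio) parameters. Two hidden blocks give exactly the 42.0M and 44.6M listed. They also give 4 Linear and 3 LayerNorm layers, which is 14 weight/bias tensors and matches the tensor count per projection prefix. The sketch assumes every Linear has a bias and every LayerNorm has an affine weight and bias, as in PyTorch's defaults. `n_hidden_blocks` is a hypothetical parameter introduced here, not part of the released code.

```python
def linear_params(d_in: int, d_out: int) -> int:
    """Weight matrix plus bias of a Linear(d_in, d_out) layer."""
    return d_in * d_out + d_out


def layernorm_params(d: int) -> int:
    """Elementwise affine weight and bias of a LayerNorm(d) layer."""
    return 2 * d


def projection_head_params(encoder_dim: int, hidden: int = 4096, out_dim: int = 768,
                           n_hidden_blocks: int = 2) -> int:
    """Parameters of the residual MLP head (GELU and Dropout have none)."""
    total = linear_params(encoder_dim, hidden) + layernorm_params(hidden)  # input layer
    total += n_hidden_blocks * (linear_params(hidden, hidden) + layernorm_params(hidden))
    total += linear_params(hidden, out_dim)  # output layer
    return total


def projection_head_tensors(n_hidden_blocks: int = 2) -> int:
    """Stored tensors: 2 per Linear and 2 per LayerNorm."""
    return 4 + 4 * n_hidden_blocks + 2


for name, dim in [("image", 1280), ("audio", 1920)]:
    for blocks in (1, 2):
        millions = projection_head_params(dim, n_hidden_blocks=blocks) / 1e6
        print(f"{name} head, {blocks} hidden block(s): {millions:.1f}M params")
```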
### Matryoshka dimensions
Embeddings can be truncated to `[768, 512, 256, 128]` dimensions while preserving retrieval quality — trained with Matryoshka Representation Learning (MRL) using weights `[1.0, 1.0, 2.0, 4.0]`.
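The MRL objective combines one contrastive loss per prefix length, scaled by the weights above. The following pure-Python sketch shows that structure using the card's dims and weights. Several details are assumptions, not the authors' training code. Queries are projected image or audio embeddings and keys are the frozen Gemma text embeddings. The loss is one-directional InfoNCE with a temperature of 0.07. The weighted losses are summed, not averaged. A real training loop would use batched tensor operations instead of Python lists.

```python
import math
import random

# Matryoshka prefix sizes and loss weights as listed in the model card
MATRYOSHKA_DIMS = [768, 512, 256, 128]
MATRYOSHKA_WEIGHTS = [1.0, 1.0, 2.0, 4.0]


def l2_normalize(vec):
    """Scale a vector to unit L2 norm."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec]


def info_nce(queries, keys, temperature=0.07):
    """One-directional InfoNCE: key i is the positive for query i, all other keys are negatives."""
    total = 0.0
    for i, q in enumerate(queries):
        logits = [sum(a * b for a, b in zip(q, k)) / temperature for k in keys]
        m = max(logits)  # subtract the max for a numerically stable log-sum-exp
        log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
        total += log_sum_exp - logits[i]  # cross-entropy with target index i
    return total / len(queries)


def matryoshka_info_nce(queries, keys, temperature=0.07):
    """Weighted sum of InfoNCE losses on truncated, re-normalized embedding prefixes."""
    loss = 0.0
    for dim, weight in zip(MATRYOSHKA_DIMS, MATRYOSHKA_WEIGHTS):
        q = [l2_normalize(v[:dim]) for v in queries]
        k = [l2_normalize(v[:dim]) for v in keys]
        loss += weight * info_nce(q, k, temperature)
    return loss


# Toy batch: 4 random "text" targets and noisy copies standing in for projected image embeddings
random.seed(0)
texts = [[random.gauss(0, 1) for _ in range(768)] for _ in range(4)]
images = [[t + random.gauss(0, 0.1) for t in text] for text in texts]
print("aligned pairs:", matryoshka_info_nce(images, texts))
print("mismatched pairs:", matryoshka_info_nce(images, texts[::-1]))
```

The 4x weight on the 128-dim term makes errors in the shortest prefix the most expensive. This matches the card's goal of keeping quality at aggressive truncation levels.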
## Benchmarks
All benchmarks run on a single NVIDIA L4 GPU with 5K samples where applicable.
### Cross-modal retrieval — SALT (5K trimodal samples)
| Direction | TEG-421M (421M) | LCO-3B (4.7B) | Nemotron-3B (4.7B) | ImageBind (1.2B) | EBind (1.78B*) |
|---|---|---|---|---|---|
| Text → Image R@1 | 0.672 | 0.660 | 0.529 | 0.712 | **0.779** |
| Image → Text R@1 | 0.620 | 0.564 | 0.299 | 0.736 | **0.783** |
| Text → Audio R@1 | **0.113** | 0.042 | 0.018 | 0.038 | 0.047 |
| Audio → Text R@1 | **0.115** | 0.032 | 0.010 | 0.039 | 0.035 |
| Audio → Image R@1 | **0.081** | 0.027 | 0.016 | 0.023 | 0.027 |
| Image → Audio R@1 | **0.083** | 0.034 | 0.018 | 0.025 | 0.032 |
TEG leads every audio cross-modal direction, by roughly 2.4x to 11.5x, over models that are about 3-11x larger. Image↔Audio R@1 improved ~40% over v1 through joint cross-modal training. On vision-text, TEG trails EBind and ImageBind but uses encoders small enough for edge deployment.
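The card reports Recall@K but does not describe how its benchmark harness computes it. For paired data such as SALT, with one positive per query, a common definition is sketched below. This is an assumption about the protocol. For example, MSCOCO has several captions per image, and image→text evaluation usually counts a hit if any of them is retrieved.

```python
def recall_at_k(similarity, k=1):
    """Fraction of queries whose paired candidate is among the k highest-scoring candidates.

    similarity[i][j] scores query i against candidate j. The correct candidate for
    query i is assumed to be candidate i (one positive per query). Ties go to the
    lower index because sorted() is stable.
    """
    hits = 0
    for i, row in enumerate(similarity):
        top_k = sorted(range(len(row)), key=lambda j: row[j], reverse=True)[:k]
        hits += i in top_k
    return hits / len(similarity)


# Toy 3x3 similarity matrix: query 1 ranks candidate 0 above its own pair
sim = [[0.9, 0.1, 0.3],
       [0.8, 0.2, 0.1],
       [0.1, 0.4, 0.7]]
print(recall_at_k(sim, k=1), recall_at_k(sim, k=2))
```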
### Audio retrieval — AudioCaps & Clotho
| Benchmark | Direction | TEG-421M | LCO-3B | Nemotron-3B | CLAP-Small | CLAP-Large | ImageBind | EBind |
|---|---|---|---|---|---|---|---|---|
| AudioCaps | A→T R@1 | 0.159 | 0.250 | 0.050 | **0.425** | 0.420 | 0.116 | 0.225 |
| AudioCaps | T→A R@1 | 0.149 | 0.215 | 0.075 | **0.315** | 0.280 | 0.080 | 0.219 |
| Clotho | A→T R@1 | 0.168 | 0.178 | 0.038 | 0.166 | **0.195** | 0.061 | 0.088 |
| Clotho | T→A R@1 | 0.123 | **0.187** | 0.070 | 0.159 | 0.167 | 0.074 | 0.118 |
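CLAP models lead on three of the four audio-only directions (LCO-3B leads Clotho T→A), but they are audio specialists with no image support. Among the trimodal models, TEG beats ImageBind in every direction and EBind on Clotho. It comes close to LCO-3B on Clotho A→T (0.168 vs 0.178) and trails it elsewhere, at about 1/11 of LCO-3B's size.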
### Image-text retrieval — MSCOCO & Flickr30k
| Benchmark | Direction | TEG-421M (421M) | EBind (1.78B*) | ImageBind (1.2B) | LCO-3B (4.7B) | Nemotron-3B (4.7B) |
|---|---|---|---|---|---|---|
| MSCOCO 5K | I→T R@1 | 0.248 | **0.743** | 0.658 | 0.533 | 0.225 |
| MSCOCO 5K | T→I R@1 | 0.180 | **0.559** | 0.490 | 0.469 | 0.334 |
| MSCOCO 5K | I→T R@10 | 0.622 | **0.948** | 0.918 | 0.784 | 0.630 |
| Flickr30k | I→T R@1 | 0.498 | — | — | **0.840** | 0.419 |
| Flickr30k | T→I R@1 | 0.358 | — | — | **0.765** | 0.563 |
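TEG's image-text retrieval trades accuracy for edge deployability. MobileNetV4-Medium has 8.5M parameters, roughly 35-75x fewer than the ViT-L (~300M) and ViT-H (~630M) vision encoders used by competitors. On MSCOCO I→T R@1, TEG edges out Nemotron-3B (0.248 vs 0.225) despite being 11x smaller, though Nemotron-3B is slightly ahead at R@10 (0.630 vs 0.622).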
### Zero-shot classification — ESC-50
| Model | Params | Accuracy |
|---|---|---|
| CLAP-Large | 67.8M | **0.905** |
| LCO-3B | 4.7B | 0.853 |
| TEG-421M | 421M | 0.820 |
| EBind | 1.78B* | 0.770 |
| CLAP-Small | 27.5M | 0.751 |
| Nemotron-3B | 4.7B | 0.727 |
| ImageBind | 1.2B | 0.664 |
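The card does not describe its zero-shot procedure or prompt template. The usual approach with a shared embedding space has three steps. First, embed one text prompt per class with the text encoder. Next, embed each clip with the audio branch. Then predict the class whose prompt embedding is most similar to the clip. The sketch below assumes that approach. Toy 3-dim vectors stand in for 768-dim embeddings, and the prompt wording in the comment is hypothetical.

```python
import math


def cosine(a, b):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def zero_shot_classify(clip_embedding, class_text_embeddings):
    """Predict the label whose text-prompt embedding is closest to the clip embedding."""
    return max(class_text_embeddings,
               key=lambda label: cosine(clip_embedding, class_text_embeddings[label]))


def accuracy(clip_embeddings, labels, class_text_embeddings):
    """Fraction of clips whose predicted label matches the ground-truth label."""
    correct = sum(zero_shot_classify(e, class_text_embeddings) == y
                  for e, y in zip(clip_embeddings, labels))
    return correct / len(labels)


# Hypothetical prompt embeddings, e.g. for "the sound of a dog barking"
classes = {"dog": [1.0, 0.0, 0.0], "rain": [0.0, 1.0, 0.0], "siren": [0.0, 0.0, 1.0]}
clips = [[0.9, 0.2, 0.1], [0.2, 0.9, 0.1], [0.6, 0.1, 0.5]]  # the last clip is misclassified as "dog"
print(accuracy(clips, ["dog", "rain", "siren"], classes))
```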
## Usage
### Loading components
```python
from safetensors.torch import load_file

# Load every tensor in the checkpoint into a flat {key: tensor} dict
tensors = load_file("teg-421m.safetensors")

def extract(prefix: str) -> dict:
    """Return the tensors whose key starts with `prefix`, with the prefix stripped (Python 3.9+)."""
    return {k.removeprefix(prefix): v for k, v in tensors.items() if k.startswith(prefix)}

# Split into per-component state dicts
gemma_sd = extract("gemma.")
image_enc_sd = extract("image_encoder.")
audio_enc_sd = extract("audio_encoder.")
image_proj_sd = extract("image_projection.")
audio_proj_sd = extract("audio_projection.")
```
### Reading metadata
```python
from safetensors import safe_open
with safe_open("teg-421m.safetensors", framework="pt") as f:
    metadata = f.metadata()  # safetensors stores metadata as a dict of strings
print(metadata)
# Keys: format, version, text_model, embed_dim, image_encoder_name,
# image_encoder_dim, audio_encoder_name, audio_encoder_dim,
# audio_sample_rate, matryoshka_dims, total_parameters, ...
```
### Matryoshka truncation
```python
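import torch
import torch.nn.functional as F

# Stand-in for full 768-dim, L2-normalized TEG embeddings of N = 4 inputs
embedding = F.normalize(torch.randn(4, 768), dim=-1)  # (N, 768)

# Truncate to the first 256 dims, then re-normalize so cosine similarity is a plain dot product again
embedding_256 = F.normalize(embedding[:, :256], dim=-1)  # (N, 256)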
```
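Truncation mainly pays off in index size. The sketch below estimates the raw vector storage for an index at each Matryoshka size. It assumes float32 values and ignores any overhead added by the index structure itself; `index_size_mib` is a hypothetical helper, not part of any library.

```python
def index_size_mib(num_vectors: int, dim: int, bytes_per_value: int = 4) -> float:
    """Raw storage in MiB for num_vectors embeddings of size dim (float32 by default)."""
    return num_vectors * dim * bytes_per_value / 2 ** 20


for dim in (768, 512, 256, 128):
    print(f"1M vectors at {dim} dims: {index_size_mib(1_000_000, dim):.0f} MiB")
```

Going from 768 to 128 dims shrinks raw storage by 6x. Whether that is worth it depends on how much recall the shorter prefix retains on your data.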
## File layout
```
teg-421m.safetensors # All components in one file (~1 GB)
```
### Tensor key prefixes
| Prefix | Component | Tensors |
|---|---|---|
| `gemma.*` | embeddinggemma-300M (bfloat16) | 316 |
| `image_encoder.*` | MobileNetV4-Medium | 462 |
| `audio_encoder.*` | EfficientAT mn20_as | 312 |
| `image_projection.*` | Deep projection head | 14 |
| `audio_projection.*` | Deep projection head | 14 |
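The counts in the table above can be checked without loading any tensor data. The helper below groups key names by their top-level prefix. `count_by_prefix` is a hypothetical name introduced here.

```python
from collections import Counter


def count_by_prefix(keys):
    """Count tensor keys by top-level prefix (the part before the first '.')."""
    return Counter(key.split(".", 1)[0] for key in keys)


print(count_by_prefix(["gemma.embed.weight", "gemma.norm.weight", "image_projection.0.bias"]))
```

To count the real file, pass `f.keys()` from inside a `safe_open("teg-421m.safetensors", framework="pt") as f` block, as in the metadata example. If the file matches the table, the counts should be 316, 462, 312, 14 and 14.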
## Training
- **Loss**: InfoNCE (contrastive) with Matryoshka Representation Learning
- **Data**: ~4.8M synthetically generated trimodal triplets (text, image, audio)
- **Hardware**: 2x NVIDIA L4 GPUs
- **Optimizer**: AdamW, lr=1e-3 (projections), weight decay=1e-4
- **Epochs**: ~22 (early stopping on validation recall)
- **Projection heads only** — encoders and Gemma are frozen during training
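The card says only that training stopped early on validation recall, at around 22 epochs. One common way to implement such a rule is a patience counter, sketched below. The patience value, `min_delta`, and which recall metric is monitored are all assumptions. In practice you would also save the projection-head weights from `best_epoch`.

```python
class EarlyStopping:
    """Signal a stop once validation recall has not improved for `patience` consecutive epochs."""

    def __init__(self, patience: int = 3, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("-inf")
        self.best_epoch = -1
        self.bad_epochs = 0

    def step(self, epoch: int, val_recall: float) -> bool:
        """Record this epoch's validation recall and return True when training should stop."""
        if val_recall > self.best + self.min_delta:
            self.best = val_recall
            self.best_epoch = epoch
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


stopper = EarlyStopping(patience=2)
for epoch, recall in enumerate([0.10, 0.20, 0.25, 0.24, 0.23]):
    if stopper.step(epoch, recall):
        print(f"stop at epoch {epoch}; best recall {stopper.best} at epoch {stopper.best_epoch}")
        break
```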
### Design decisions
- **Frozen source encoders**: MobileNetV4 and EfficientAT are kept frozen; only projection heads are trained via distillation into Gemma's space
- **Deep projection heads**: Residual MLPs with dropout outperformed shallow 2-layer heads, especially for audio
- **Matryoshka weighting**: Higher weight on smaller dimensions (4x at 128-dim) ensures quality at aggressive truncation levels
- **Edge-first**: Source encoders chosen for edge deployment — MobileNetV4-Medium and EfficientAT mn20 can run on devices like Raspberry Pi 5
*\*EBind's [HuggingFace checkpoint](https://huggingface.co/encord-team/ebind-full) is 8.93M parameters (bridge heads only), but inference requires frozen backbones (SigLIP ViT-L, CLAP HTSAT, text encoder) totaling 1.78B loaded parameters as measured by our benchmark harness.*
## Limitations
- Audio retrieval lags behind specialist models like CLAP on audio-only benchmarks
- Image-text retrieval gives up some accuracy relative to larger vision encoders (SigLIP, CLIP ViT-L) in exchange for edge deployability
- Trained primarily on synthetic data — real-world distribution shifts may affect performance
- Text modality requires the full embeddinggemma-300M model (307M params, bfloat16)
## Links
- **Website**: [augmem.ai](https://augmem.ai)
- **GitHub**: [github.com/augmem](https://github.com/augmem)
- **GGUF variant**: [augmem/teg-421m-gguf](https://huggingface.co/augmem/teg-421m-gguf)
## License
Apache 2.0