---
language:
- en
license: apache-2.0
tags:
- multimodal
- embedding
- matryoshka
- trimodal
- image-text-audio
- retrieval
- cross-modal
- edge
- rag
library_name: safetensors
pipeline_tag: feature-extraction
base_model:
- google/embeddinggemma-300m
datasets:
- custom
---

# TEG-421M — Trimodal Embeddings Gemma

**TEG** (Trimodal Embeddings Gemma) maps images and audio into the same embedding space as text, enabling cross-modal retrieval with a single vector index. All three modalities share a unified 768-dim space via [Google's embeddinggemma-300M](https://huggingface.co/google/embeddinggemma-300m), with full Matryoshka truncation support down to 128 dims.

> Also available in [GGUF format](https://huggingface.co/augmem/teg-421m-gguf) for quantized edge deployment.
## Architecture

TEG combines lightweight edge encoders with deep projection heads that distill into Gemma's embedding space:

```
Text  ──→ embeddinggemma-300M ──────────────────────→ 768-dim (L2-normalized)
Image ──→ MobileNetV4-Medium (1280-d) ──→ DeepProjectionHead ──→ 768-dim
Audio ──→ EfficientAT mn20_as (1920-d) ──→ DeepProjectionHead ──→ 768-dim
```
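
The composition is small enough to write out directly. The sketch below is illustrative (the `TrimodalEmbedder` wrapper and its method names are not part of the release); it only assumes each encoder returns pooled features of the listed width, and it shows the one invariant that matters for retrieval: every branch ends in an L2-normalized 768-dim vector, so a single dot-product index serves all three modalities.

```python
import torch
import torch.nn.functional as F
from torch import nn

class TrimodalEmbedder(nn.Module):
    """Illustrative wrapper around the frozen encoders and trained projection heads."""

    def __init__(self, text_encoder, image_encoder, audio_encoder, image_proj, audio_proj):
        super().__init__()
        self.text_encoder = text_encoder    # embeddinggemma-300M: text -> (N, 768), normalized
        self.image_encoder = image_encoder  # MobileNetV4-Medium:  image -> (N, 1280) pooled
        self.audio_encoder = audio_encoder  # EfficientAT mn20_as: audio -> (N, 1920) pooled
        self.image_proj = image_proj        # DeepProjectionHead: 1280 -> 768
        self.audio_proj = audio_proj        # DeepProjectionHead: 1920 -> 768

    @torch.no_grad()
    def encode_image(self, pixels):
        return F.normalize(self.image_proj(self.image_encoder(pixels)), dim=-1)  # (N, 768)

    @torch.no_grad()
    def encode_audio(self, mel):
        return F.normalize(self.audio_proj(self.audio_encoder(mel)), dim=-1)     # (N, 768)

    @torch.no_grad()
    def encode_text(self, texts):
        return self.text_encoder(texts)  # already 768-dim and L2-normalized
```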

| Component | Architecture | Params | Size |
|---|---|---|---|
| Text encoder | embeddinggemma-300M (bfloat16) | 307.6M | 586.7 MB |
| Image encoder | MobileNetV4-Medium (timm) | 8.5M | 32.4 MB |
| Audio encoder | EfficientAT mn20_as | 18.0M | 68.5 MB |
| Image projection | DeepProjectionHead (1280 → 768) | 42.0M | 160.1 MB |
| Audio projection | DeepProjectionHead (1920 → 768) | 44.6M | 170.1 MB |
| **Total** | | **420.6M** | **1017.8 MB** |

### Projection head detail

Each `DeepProjectionHead` is a residual MLP:

```
Linear(encoder_dim, 4096) → GELU → LayerNorm → Dropout(0.3)
→ Linear(4096, 4096) → GELU → LayerNorm → Dropout(0.3) + residual
→ Linear(4096, 768)
```
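
A minimal PyTorch sketch of this head. It is a reconstruction from the diagram, not the released implementation: residual placement may differ, and the per-head parameter counts in the table above (42.0M / 44.6M, 14 tensors each) are consistent with two 4096 → 4096 hidden blocks rather than the single one drawn, so the block count is left configurable.

```python
import torch
from torch import nn

class DeepProjectionHead(nn.Module):
    """Residual MLP projection head, reconstructed from the diagram above (illustrative)."""

    def __init__(self, encoder_dim: int, embed_dim: int = 768,
                 hidden_dim: int = 4096, num_blocks: int = 1, dropout: float = 0.3):
        super().__init__()
        self.stem = nn.Sequential(
            nn.Linear(encoder_dim, hidden_dim), nn.GELU(),
            nn.LayerNorm(hidden_dim), nn.Dropout(dropout),
        )
        self.blocks = nn.ModuleList(
            nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim), nn.GELU(),
                nn.LayerNorm(hidden_dim), nn.Dropout(dropout),
            )
            for _ in range(num_blocks)
        )
        self.out = nn.Linear(hidden_dim, embed_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.stem(x)
        for block in self.blocks:
            h = h + block(h)  # residual connection around each hidden block
        return self.out(h)

# Image head: DeepProjectionHead(1280); audio head: DeepProjectionHead(1920)
```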

### Matryoshka dimensions

Embeddings can be truncated to `[768, 512, 256, 128]` dimensions while preserving retrieval quality — trained with Matryoshka Representation Learning (MRL) using weights `[1.0, 1.0, 2.0, 4.0]`.

## Benchmarks

All benchmarks run on a single NVIDIA L4 GPU with 5K samples where applicable.

### Cross-modal retrieval — SALT (5K trimodal samples)

| Direction | TEG-421M (421M) | LCO-3B (4.7B) | Nemotron-3B (4.7B) | ImageBind (1.2B) | EBind |
|---|---|---|---|---|---|
| Text → Image R@1 | 0.672 | 0.660 | 0.529 | 0.712 | **0.779** |
| Image → Text R@1 | 0.620 | 0.564 | 0.299 | 0.736 | **0.783** |
| Text → Audio R@1 | **0.113** | 0.042 | 0.018 | 0.038 | 0.047 |
| Audio → Text R@1 | **0.115** | 0.032 | 0.010 | 0.039 | 0.035 |
| Audio → Image R@1 | **0.081** | 0.027 | 0.016 | 0.023 | 0.027 |
| Image → Audio R@1 | **0.083** | 0.034 | 0.018 | 0.025 | 0.032 |

TEG leads all audio cross-modal directions by 2-10x over models that are 3-11x larger. Image↔Audio improved ~40% over v1 via joint cross-modal training. Vision-text trails EBind/ImageBind but uses encoders small enough for edge deployment.

### Audio retrieval — AudioCaps & Clotho

| Benchmark | Direction | TEG-421M | LCO-3B | Nemotron-3B | CLAP-Small | CLAP-Large | ImageBind | EBind |
|---|---|---|---|---|---|---|---|---|
| AudioCaps | A→T R@1 | 0.159 | 0.250 | 0.050 | **0.425** | 0.420 | 0.116 | 0.225 |
| AudioCaps | T→A R@1 | 0.149 | 0.215 | 0.075 | **0.315** | 0.280 | 0.080 | 0.219 |
| Clotho | A→T R@1 | 0.168 | 0.178 | 0.038 | 0.166 | **0.195** | 0.061 | 0.088 |
| Clotho | T→A R@1 | 0.123 | **0.187** | 0.070 | 0.159 | 0.167 | 0.074 | 0.118 |

The CLAP models are audio specialists with no image support and lead on audio-only benchmarks. Among trimodal models, TEG is competitive with LCO while being 11x smaller.

### Image-text retrieval — MSCOCO & Flickr30k

| Benchmark | Direction | TEG-421M (421M) | EBind (1.78B*) | ImageBind (1.2B) | LCO-3B (4.7B) | Nemotron-3B (4.7B) |
|---|---|---|---|---|---|---|
| MSCOCO 5K | I→T R@1 | 0.248 | **0.743** | 0.658 | 0.533 | 0.225 |
| MSCOCO 5K | T→I R@1 | 0.180 | **0.559** | 0.490 | 0.469 | 0.334 |
| MSCOCO 5K | I→T R@10 | 0.622 | **0.948** | 0.918 | 0.784 | 0.630 |
| Flickr30k | I→T R@1 | 0.498 | — | — | **0.840** | 0.419 |
| Flickr30k | T→I R@1 | 0.358 | — | — | **0.765** | 0.563 |

TEG's image-text retrieval trades accuracy for edge deployability — MobileNetV4-Medium is ~100x smaller than the ViT-H/ViT-L encoders used by competitors. On MSCOCO, TEG outperforms Nemotron-3B on I→T despite being 11x smaller.

### Zero-shot classification — ESC-50

| Model | Params | Accuracy |
|---|---|---|
| CLAP-Large | 67.8M | **0.905** |
| LCO-3B | 4.7B | 0.853 |
| TEG-421M | 421M | 0.820 |
| EBind | 1.78B* | 0.770 |
| CLAP-Small | 27.5M | 0.751 |
| Nemotron-3B | 4.7B | 0.727 |
| ImageBind | 1.2B | 0.664 |

## Usage

### Loading components

```python
from safetensors.torch import load_file

# Load entire model
tensors = load_file("teg-421m.safetensors")

# Extract components by prefix
gemma_sd = {k.removeprefix("gemma."): v for k, v in tensors.items() if k.startswith("gemma.")}
image_enc_sd = {k.removeprefix("image_encoder."): v for k, v in tensors.items() if k.startswith("image_encoder.")}
audio_enc_sd = {k.removeprefix("audio_encoder."): v for k, v in tensors.items() if k.startswith("audio_encoder.")}
image_proj_sd = {k.removeprefix("image_projection."): v for k, v in tensors.items() if k.startswith("image_projection.")}
audio_proj_sd = {k.removeprefix("audio_projection."): v for k, v in tensors.items() if k.startswith("audio_projection.")}
```
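
Continuing from the snippet above, the extracted state dicts can be loaded into freshly instantiated modules. This is a hedged sketch: the timm identifier, the input resolution, and the `DeepProjectionHead` layout (from the Architecture section) are assumptions, so non-strict loading and a key-name check are advisable.

```python
import timm
import torch
import torch.nn.functional as F

# Rebuild the image encoder as a headless, pooled feature extractor.
# "mobilenetv4_conv_medium" is an assumed timm identifier; check the safetensors
# metadata key image_encoder_name for the authoritative one.
image_encoder = timm.create_model("mobilenetv4_conv_medium", pretrained=False, num_classes=0)
image_encoder.load_state_dict(image_enc_sd, strict=False)  # strict=False: key names may differ

# Projection head as sketched in the Architecture section (the layout may need
# adjusting to match the released tensors before strict loading succeeds).
image_projection = DeepProjectionHead(encoder_dim=1280)
image_projection.load_state_dict(image_proj_sd, strict=False)

image_encoder.eval()
image_projection.eval()

with torch.no_grad():
    pixels = torch.randn(1, 3, 224, 224)  # placeholder for a preprocessed image batch
    embedding = F.normalize(image_projection(image_encoder(pixels)), dim=-1)  # (1, 768)
```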

### Reading metadata

```python
from safetensors import safe_open

with safe_open("teg-421m.safetensors", framework="pt") as f:
    metadata = f.metadata()
    print(metadata)
    # Keys: format, version, text_model, embed_dim, image_encoder_name,
    #       image_encoder_dim, audio_encoder_name, audio_encoder_dim,
    #       audio_sample_rate, matryoshka_dims, total_parameters, ...
```
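
safetensors stores metadata values as plain strings, so list- and integer-valued keys need parsing before use. A small sketch that reuses `metadata` from the block above and assumes `matryoshka_dims` is serialized as a JSON-style list (the exact serialization is not documented here):

```python
import json

embed_dim = int(metadata["embed_dim"])                     # e.g. 768
matryoshka_dims = json.loads(metadata["matryoshka_dims"])  # e.g. [768, 512, 256, 128]
sample_rate = int(metadata["audio_sample_rate"])           # expected audio sample rate in Hz
```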

### Matryoshka truncation

```python
import torch.nn.functional as F

# Full 768-dim embedding from any of the three modality encoders
embedding = model(inputs)  # (N, 768)

# Truncate to 256-dim and re-normalize
embedding_256 = F.normalize(embedding[:, :256], dim=-1)
```
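
Because truncated embeddings are re-normalized to unit length, retrieval at any Matryoshka dimension remains a plain dot product. A minimal sketch of cross-modal search at 256 dims, where `text_embs` and `image_embs` stand in for embeddings produced as above:

```python
import torch
import torch.nn.functional as F

dim = 256
queries = F.normalize(text_embs[:, :dim], dim=-1)   # (Q, 256) truncated text queries
index = F.normalize(image_embs[:, :dim], dim=-1)    # (N, 256) truncated image index

scores = queries @ index.T                           # cosine similarity (unit vectors)
top5 = torch.topk(scores, k=5, dim=-1).indices       # top-5 image ids per query
```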

## File layout

```
teg-421m.safetensors    # All components in one file (~1 GB)
```

### Tensor key prefixes

| Prefix | Component | Tensors |
|---|---|---|
| `gemma.*` | embeddinggemma-300M (bfloat16) | 316 |
| `image_encoder.*` | MobileNetV4-Medium | 462 |
| `audio_encoder.*` | EfficientAT mn20_as | 312 |
| `image_projection.*` | Deep projection head | 14 |
| `audio_projection.*` | Deep projection head | 14 |

## Training

- **Loss**: InfoNCE (contrastive) with Matryoshka Representation Learning (see the sketch after this list)
- **Data**: ~4.8M synthetically generated trimodal triplets (text, image, audio)
- **Hardware**: 2x NVIDIA L4 GPUs
- **Optimizer**: AdamW, lr=1e-3 (projections), weight decay=1e-4
- **Epochs**: ~22 (early stopping on validation recall)
- **Projection heads only** — encoders and Gemma are frozen during training

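
The objective can be sketched as a symmetric InfoNCE loss evaluated at every Matryoshka dimension and combined with the weights listed earlier. This is illustrative only: the temperature value and the exact per-batch pairing of modalities are assumptions not stated in this card.

```python
import torch
import torch.nn.functional as F

MRL_DIMS = [768, 512, 256, 128]
MRL_WEIGHTS = [1.0, 1.0, 2.0, 4.0]

def info_nce(a: torch.Tensor, b: torch.Tensor, temperature: float = 0.07) -> torch.Tensor:
    """Symmetric InfoNCE over a batch of paired, normalized embeddings (N, D)."""
    logits = a @ b.T / temperature                       # (N, N) similarity matrix
    targets = torch.arange(a.size(0), device=a.device)   # matching pairs lie on the diagonal
    return 0.5 * (F.cross_entropy(logits, targets) + F.cross_entropy(logits.T, targets))

def mrl_info_nce(text_emb: torch.Tensor, other_emb: torch.Tensor) -> torch.Tensor:
    """Weighted sum of InfoNCE losses at each Matryoshka dimension (illustrative)."""
    loss = 0.0
    for dim, weight in zip(MRL_DIMS, MRL_WEIGHTS):
        a = F.normalize(text_emb[:, :dim], dim=-1)
        b = F.normalize(other_emb[:, :dim], dim=-1)
        loss = loss + weight * info_nce(a, b)
    return loss
```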

### Design decisions

- **Frozen source encoders**: MobileNetV4 and EfficientAT are kept frozen; only projection heads are trained via distillation into Gemma's space
- **Deep projection heads**: Residual MLPs with dropout outperformed shallow 2-layer heads, especially for audio
- **Matryoshka weighting**: Higher weight on smaller dimensions (4x at 128-dim) ensures quality at aggressive truncation levels
- **Edge-first**: Source encoders chosen for edge deployment — MobileNetV4-Medium and EfficientAT mn20 can run on devices like Raspberry Pi 5

*\*EBind's [HuggingFace checkpoint](https://huggingface.co/encord-team/ebind-full) is 8.93M parameters (bridge heads only), but inference requires frozen backbones (SigLIP ViT-L, CLAP HTSAT, text encoder) totaling 1.78B loaded parameters as measured by our benchmark harness.*

## Limitations

- Audio retrieval lags behind specialist models like CLAP on audio-only benchmarks
- Image-text retrieval trades some accuracy vs larger vision encoders (SigLIP, CLIP ViT-L) for edge deployability
- Trained primarily on synthetic data — real-world distribution shifts may affect performance
- Text modality requires the full embeddinggemma-300M model (307M params, bfloat16)

## Links

- **Website**: [augmem.ai](https://augmem.ai)
- **GitHub**: [github.com/augmem](https://github.com/augmem)
- **GGUF variant**: [augmem/teg-421m-gguf](https://huggingface.co/augmem/teg-421m-gguf)

## License

Apache 2.0