| --- |
| language: |
| - en |
| - multilingual |
| license: apache-2.0 |
| library_name: transformers |
| tags: |
| - feature-extraction |
| - audio |
| - speech |
| - conformer |
| - gemma4 |
| - usm |
| - google |
| - safetensors |
| pipeline_tag: feature-extraction |
| base_model: google/gemma-4-E2B-it |
model-index:
- name: gemma4-audio-encoder
  results:
  - task:
      type: audio-classification
      name: Speech Commands (35-class)
    dataset:
      name: Google Speech Commands v0.02
      type: google/speech_commands
      split: validation
    metrics:
    - type: accuracy
      value: 72.0
      name: Linear Probe Accuracy
| --- |
| |
| # Gemma 4 Audio Encoder (USM-style Conformer) |
|
|
Standalone extraction of the audio encoder from Google's [Gemma 4](https://huggingface.co/google/gemma-4-E2B-it) multimodal model family. This is a 304.8M-parameter USM-style Conformer that converts audio waveforms (via a 128-bin mel spectrogram) into embeddings.
|
|
**License:** Apache 2.0 (inherited from Gemma 4, with no usage restrictions beyond the Apache 2.0 terms)
|
|
| ## Architecture |
|
|
| | Property | Value | |
| |---|---| |
| | Total parameters | 304.8M | |
| | Architecture | USM-style Conformer (Macaron-net) | |
| | Hidden dimension | 1024 (pure audio representation) | |
| | Output dimension | 1536 (text-projected via `output_proj`) | |
| | Conformer layers | 12 | |
| | Attention heads | 8 (128 dim per head) | |
| | FFW intermediate | 4096 (4× expansion) | |
| | Depthwise conv kernel | 5 | |
| | Subsampling conv channels | [128, 32] | |
| | Input | 128-bin mel spectrogram @ 16kHz | |
| | Conformer activation | SiLU | |
| | Subsampling activation | ReLU | |
| | Conformer normalization | RMSNorm (eps=1e-6) | |
| | Subsampling normalization | LayerNorm | |
| | Residual weight | 0.5 (Macaron half-step) | |
| | Attention type | Chunked causal (chunk_size=12, left_context=13, right_context=0) | |
| | Clipped linears | Yes (quantization-ready input_min/max, output_min/max per layer) | |
| | Temporal downsampling | 4× (two stride-2 Conv2d layers) | |
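
The "chunked causal" attention row can be made concrete with a small mask builder. This is a sketch under stated assumptions, not the model's code: it assumes `left_context=13` counts the current frame (so each frame sees itself plus up to 12 earlier frames), and it treats `chunk_size=12` as a blocking detail for efficient computation that does not change which frames are visible. The function name `local_causal_mask` is hypothetical.

```python
def local_causal_mask(num_frames, left_context=13, right_context=0):
    """Return mask[q][k] == True if query frame q may attend to key frame k.

    ASSUMPTION (not verified against the model code): left_context includes the
    current frame, so the visible window is [q - (left_context - 1), q + right_context].
    """
    max_past = left_context - 1
    return [
        [q - max_past <= k <= q + right_context for k in range(num_frames)]
        for q in range(num_frames)
    ]


mask = local_causal_mask(num_frames=30)
print(sum(mask[20]))  # 13 visible frames: frame 20 itself plus frames 8..19
print(mask[20][21])   # False: right_context=0, so no look-ahead
```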
| |
| ### Conformer Block Structure |
| |
Each of the 12 conformer blocks follows the Macaron-net pattern: two half-weighted feed-forward modules sandwich the attention and convolution modules, and every module's output is added back to the running hidden state `x`:

```
Input x
  → FFW1: pre_layer_norm → Linear(1024→4096) → SiLU → Linear(4096→1024) → post_layer_norm
  → x = x + 0.5 × FFW1(x)
  → Self-Attention: norm_pre_attn → Q/K/V proj (1024→1024) → relative position → post proj → norm_post_attn
  → x = x + Attn(x)
  → LightConv1d: pre_layer_norm → Linear(1024→2048, gated) → DepthwiseConv1d(k=5) → conv_norm → Linear(1024→1024)
  → x = x + Conv(x)
  → FFW2: pre_layer_norm → Linear(1024→4096) → SiLU → Linear(4096→1024) → post_layer_norm
  → x = x + 0.5 × FFW2(x)
  → norm_out
```
| |
| ### Input/Output Shapes |
| |
| - **Input:** `(batch, time_frames, 128)` — 128-bin mel features, time-first |
| - **Output:** `(batch, time_frames/4, 1536)` — 4× temporal downsampling, projected to 1536 |
| - For 4 seconds of 16kHz audio: input ~(1, 399, 128) → output ~(1, 100, 1536) |
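
The ~399 → ~100 frame counts can be reproduced with standard framing and convolution arithmetic. This sketch uses assumed values that are not read from this repo's configs: a 320-sample (20 ms) window and 160-sample (10 ms) hop with no padding for the mel frontend, and kernel 3, stride 2, padding 1 for each subsampling Conv2d. Check `preprocessor_config.json` and `config.json` for the real parameters.

```python
def num_mel_frames(num_samples, frame_length=320, hop_length=160):
    # Frames of a non-centered STFT: number of full windows that fit in the signal.
    # frame_length/hop_length are ASSUMED values (20 ms / 10 ms at 16 kHz).
    return (num_samples - frame_length) // hop_length + 1


def conv_out_len(length, kernel=3, stride=2, padding=1):
    # Standard convolution output length along the time axis (ASSUMED kernel/padding).
    return (length + 2 * padding - kernel) // stride + 1


frames = num_mel_frames(4 * 16000)               # 4 s at 16 kHz
out_frames = conv_out_len(conv_out_len(frames))  # two stride-2 layers -> ~4x downsampling
print(frames, out_frames)  # 399 100
```

At roughly 100 output frames per 4 s of audio, each encoder output frame covers about 40 ms.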
|
|
| ## Usage |
|
|
```python
import numpy as np
import torch
from transformers import Gemma4AudioFeatureExtractor, Gemma4AudioModel

# Load the audio encoder directly from this repo
audio_tower = Gemma4AudioModel.from_pretrained(
    "rnagabh/gemma4-audio-encoder",
    torch_dtype=torch.bfloat16,
)
audio_tower.to("cuda")
audio_tower.eval()

# Load the feature extractor (saved in this repo)
feature_extractor = Gemma4AudioFeatureExtractor.from_pretrained("rnagabh/gemma4-audio-encoder")

# Extract mel features from audio (random noise stands in for a real 4 s clip)
waveform = np.random.randn(64000).astype(np.float32)  # 4 s @ 16 kHz
inputs = feature_extractor([waveform], sampling_rate=16000, return_tensors="pt")
mel = inputs["input_features"].to(dtype=torch.bfloat16, device="cuda")

# === Option 1: Text-projected embeddings (1536-dim) ===
# Use this if feeding into an LLM or if you need the full model output.
with torch.no_grad():
    output = audio_tower(mel)
text_projected = output.last_hidden_state  # (1, 100, 1536)

# === Option 2: Pure audio embeddings (1024-dim) ===
# Captures the conformer output BEFORE the text projection layer.
# Recommended for downstream audio tasks (classification, verification, etc.)
# Note: this registers a forward hook on output_proj and runs a separate forward pass.
pre_proj_features = {}

def hook_fn(module, args, output):
    # args holds output_proj's positional inputs; args[0] is the 1024-dim hidden state
    pre_proj_features["hidden"] = args[0]

handle = audio_tower.output_proj.register_forward_hook(hook_fn)
try:
    with torch.no_grad():
        _ = audio_tower(mel)
finally:
    handle.remove()  # detach the hook even if the forward pass raises
audio_embeddings = pre_proj_features["hidden"]  # (1, 100, 1024)
```
|
|
> **Which to use?** For audio-only tasks (classification, speaker verification, deepfake detection),
> the 1024-dim pre-projection embeddings are usually the better choice: they retain acoustic detail that the
> text projection may discard. The 1536-dim output is designed to be fed into an LLM decoder.
|
|
| ## Critical: The AutoModel Loading Gotcha |
|
|
⚠️ **`AutoModel.from_pretrained("google/gemma-4-E2B-it")` fails to load the audio tower weights without raising an error.**

All audio tower parameters are left at their random initialization (std ≈ 0.02). The model runs without errors and produces outputs of the correct shape, but those outputs are meaningless.

**Root cause:** The checkpoint stores keys with a `model.` prefix (e.g., `model.audio_tower.layers.0...`), while `AutoModel` builds a module tree that expects keys without that prefix. Because of the mismatch, each affected checkpoint key is reported as UNEXPECTED and the corresponding model parameter as MISSING. Transformers reports missing and unexpected keys as load-time warnings rather than errors, so loading continues and every missing parameter is freshly initialized.
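
The key mismatch can be illustrated with plain lists of names. The key names below are hypothetical examples in the style of the checkpoint keys, and `diagnose`/`strip_prefix` are illustrative helpers, not Transformers APIs; the recommended fix remains loading through `AutoModelForMultimodalLM`.

```python
def diagnose(checkpoint_keys, model_keys):
    """Return (missing, unexpected) the way a loader would report them."""
    ckpt, model = set(checkpoint_keys), set(model_keys)
    missing = sorted(model - ckpt)     # parameters the model wanted but did not receive
    unexpected = sorted(ckpt - model)  # checkpoint entries the model had no slot for
    return missing, unexpected


def strip_prefix(keys, prefix="model."):
    # Remove a leading prefix so checkpoint keys line up with the module tree
    return [k[len(prefix):] if k.startswith(prefix) else k for k in keys]


# Hypothetical key names, modeled on the `model.audio_tower...` pattern
checkpoint = ["model.audio_tower.layers.0.attn.q_proj.weight",
              "model.audio_tower.output_proj.weight"]
expected = ["audio_tower.layers.0.attn.q_proj.weight",
            "audio_tower.output_proj.weight"]

missing, unexpected = diagnose(checkpoint, expected)
print(len(missing), len(unexpected))                 # 2 2: every key is both missing and unexpected
print(diagnose(strip_prefix(checkpoint), expected))  # ([], [])
```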
| |
| **Fix:** Use `AutoModelForMultimodalLM` instead: |
|
|
```python
from transformers import AutoModel, AutoModelForMultimodalLM

# ❌ WRONG: audio tower weights are randomly initialized
model = AutoModel.from_pretrained("google/gemma-4-E2B-it")
audio_tower = model.audio_tower  # RANDOM WEIGHTS

# ✅ CORRECT: audio tower weights load properly
model = AutoModelForMultimodalLM.from_pretrained("google/gemma-4-E2B-it")
audio_tower = model.model.audio_tower  # TRAINED WEIGHTS
```
|
|
| **How to verify:** |
|
|
| ```python |
| w = audio_tower.output_proj.weight.float() |
| print(f"std={w.std().item():.6f}") |
| # ✅ Trained: std ≈ 0.031250 |
| # ❌ Random: std ≈ 0.019884 |
| ``` |
|
|
| ## E2B and E4B Share Identical Audio Weights |
|
|
| The audio encoder weights are **byte-for-byte identical** between Gemma 4 E2B and E4B. This was verified empirically — all 751 parameter tensors match exactly. |
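
A byte-for-byte comparison like this can be done by hashing each tensor's raw bytes. The sketch below works on plain `name -> bytes` dictionaries; converting real checkpoint tensors to bytes is left out, and the helper names and toy data are hypothetical.

```python
import hashlib


def tensor_digests(state):
    # Map each tensor name to a SHA-256 digest of its raw bytes
    return {name: hashlib.sha256(raw).hexdigest() for name, raw in state.items()}


def compare_state_dicts(a, b):
    """Return (names only in a, names only in b, shared names whose bytes differ)."""
    da, db = tensor_digests(a), tensor_digests(b)
    only_a = sorted(da.keys() - db.keys())
    only_b = sorted(db.keys() - da.keys())
    differing = sorted(n for n in da.keys() & db.keys() if da[n] != db[n])
    return only_a, only_b, differing


# Toy stand-ins for two audio-tower state dicts (names and bytes are made up)
e2b = {"layers.0.w": b"\x01\x02", "output_proj.weight": b"\x03\x04"}
e4b = {"layers.0.w": b"\x01\x02", "output_proj.weight": b"\x03\x04"}
print(compare_state_dicts(e2b, e4b))  # ([], [], []): identical
```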
|
|
| Gemma 4's E2B is a MatFormer sub-model nested inside E4B. The MatFormer architecture only affects the text decoder's feed-forward dimensions. The audio tower sits outside the MatFormer nesting and is a shared module. |
|
|
| **Implication:** There is no reason to prefer E4B over E2B for audio encoder extraction. E2B is a smaller download (~10GB vs ~16GB). |
|
|
| ## Files in This Repo |
|
|
| | File | Description | Size | |
| |---|---|---| |
| | `config.json` | Audio tower config (Gemma4AudioConfig) | <1 KB | |
| | `model.safetensors` | Audio tower weights (304.8M params, BF16) | 609.7 MB | |
| | `preprocessor_config.json` | Mel spectrogram feature extractor config | <1 KB | |
| | `embed_audio.safetensors` | Audio→text embedding projection (1536→1536) | 4.7 MB | |
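
The file sizes follow from the parameter counts at 2 bytes per BF16 value. This sketch assumes `embed_audio.safetensors` holds a single 1536×1536 weight matrix with no bias (an assumption), ignores the small safetensors header, and reports decimal megabytes.

```python
BYTES_PER_BF16 = 2


def megabytes(num_params, bytes_per_param=BYTES_PER_BF16):
    # Decimal MB (1 MB = 1e6 bytes), ignoring the safetensors JSON header
    return num_params * bytes_per_param / 1e6


embed_audio_params = 1536 * 1536  # assumed: one square weight matrix, no bias
encoder_params = 304.8e6          # total quoted in the Architecture table

print(f"{megabytes(embed_audio_params):.1f} MB")  # 4.7 MB
print(f"{megabytes(encoder_params):.1f} MB")      # 609.6 MB (304.8M is itself rounded)
```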
|
|
| ## Limitations |
|
|
| - **End-to-end trained for LLM decoding:** The encoder was trained to produce features for Gemma 4's text decoder, not as a general-purpose audio encoder. For standalone feature extraction, the 1024-dim pre-projection output (before `output_proj`) may be more useful than the 1536-dim post-projection output. |
- **Causal chunked attention:** The encoder uses right_context=0, meaning it cannot look ahead. In offline (non-streaming) settings, where the full utterance is available, this can leave it at a disadvantage relative to bidirectional encoders.
- **Multi-layer fusion doesn't help:** Unlike wav2vec2/W2v-BERT, where combining multiple hidden layers improves downstream performance, only this encoder's final-layer output is useful. The Macaron half-step residuals and causal attention are a plausible, though unconfirmed, explanation.
| - **Subsampling frontend uses ReLU + LayerNorm** (not SiLU + GroupNorm as in some USM descriptions). |
| - **Not a speaker encoder:** While embeddings show some speaker separation (cosine similarity gap of ~0.03), this model was not trained for speaker verification. Dedicated speaker models will significantly outperform it on speaker tasks. |
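
A back-of-the-envelope estimate of how far back the encoder can look, using the Architecture table plus two assumptions: each output frame spans about 40 ms (4 s of audio → ~100 frames), and `left_context=13` includes the current frame, so attention reaches 12 frames back per layer. The estimate covers attention only and ignores the convolution modules and the subsampling frontend.

```python
FRAME_MS = 4000 / 100  # ~40 ms per output frame (4 s -> ~100 frames)
LEFT_CONTEXT = 13      # from the Architecture table
NUM_LAYERS = 12

past_frames_per_layer = LEFT_CONTEXT - 1  # ASSUMED: left_context counts the current frame
total_past_frames = past_frames_per_layer * NUM_LAYERS
print(total_past_frames, total_past_frames * FRAME_MS / 1000)  # 144 5.76
```

Under these assumptions, each output frame can draw on roughly the previous 5.8 s of audio through stacked attention, and on none of the future.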
| |
| ## Benchmark Results (frozen 1024-dim embeddings, linear probe) |
| |
| ### Speech Commands Classification (35 classes) |
| |
| | Metric | Value | |
| |---|---| |
| | Linear probe accuracy | **72.0%** | |
| | Random baseline | 2.9% | |
| | Improvement over chance | **25×** | |
| | Dataset | Google Speech Commands v0.02 (validation) | |
| | Probe | Logistic regression on L2-normalized mean-pooled embeddings | |
| |
The encoder captures phonetic and lexical content well: it is strong on acoustically distinct words (seven: 0.93 F1; house/stop/eight: 0.89 F1) and weaker on similar-sounding pairs (three/tree).
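
The probe's input features can be sketched as: mean-pool the `(time, 1024)` embeddings over time, then L2-normalize. The optional frame mask for padded batches is an addition of this sketch (the card does not say how padding was handled), and the logistic-regression step itself is not shown. A bf16 torch tensor such as `audio_embeddings` would first need `.float().cpu().numpy()`.

```python
import numpy as np


def pool_and_normalize(embeddings, mask=None, eps=1e-12):
    """embeddings: (batch, time, dim); mask: optional (batch, time) with 1 for real frames."""
    if mask is None:
        mask = np.ones(embeddings.shape[:2])
    mask = mask[..., None].astype(embeddings.dtype)  # (batch, time, 1)
    pooled = (embeddings * mask).sum(axis=1) / np.maximum(mask.sum(axis=1), 1.0)
    norms = np.linalg.norm(pooled, axis=-1, keepdims=True)
    return pooled / np.maximum(norms, eps)  # unit-length rows


rng = np.random.default_rng(0)
emb = rng.standard_normal((2, 100, 1024)).astype(np.float32)  # stand-in for real embeddings
feats = pool_and_normalize(emb)
print(feats.shape)  # (2, 1024)
```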
| |
| ### Speaker Similarity (LibriSpeech test-clean) |
| |
| | Metric | Value | |
| |---|---| |
| | Same-speaker cosine similarity | 0.656 ± 0.147 | |
| | Different-speaker cosine similarity | 0.622 ± 0.132 | |
| | Separation gap | 0.034 | |
| |
| Modest speaker separation — expected since this is an ASR-oriented encoder, not a speaker verification model. |
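
The separation gap is the mean same-speaker cosine similarity minus the mean different-speaker cosine similarity. A minimal pure-Python sketch of that computation follows; the all-pairs scheme and the toy vectors are assumptions for illustration, not the protocol behind the reported numbers.

```python
import math
from itertools import combinations


def cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def separation_gap(embeddings, speakers):
    """Mean same-speaker cosine minus mean different-speaker cosine over all pairs."""
    same, diff = [], []
    for i, j in combinations(range(len(embeddings)), 2):
        sim = cosine(embeddings[i], embeddings[j])
        (same if speakers[i] == speakers[j] else diff).append(sim)
    return sum(same) / len(same) - sum(diff) / len(diff)


# Toy utterance-level embeddings for two speakers
embs = [[1.0, 0.0], [0.8, 0.6], [0.0, 1.0], [0.6, 0.8]]
spk = ["A", "A", "B", "B"]
print(round(separation_gap(embs, spk), 3))  # 0.26
```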
| |
|  |
| |
| ### t-SNE Speaker Clustering |
| |
|  |
| |
| Embeddings show partial speaker clustering — the encoder captures speaker characteristics as a byproduct of ASR training, but is not optimized for speaker discrimination. |
| |
| ## Extraction Details |
| |
| - Extracted from `google/gemma-4-E2B-it` using `AutoModelForMultimodalLM` |
| - Weights saved in BF16 as safetensors |
| - Forward pass verified: extracted model produces outputs with **0.0 max absolute difference** from the original |
| - All architecture specs independently verified against the live model |
| |
| ## References |
| |
| - [Gemma 4 on HuggingFace](https://huggingface.co/google/gemma-4-E2B-it) |
| - [Gemma 4 Blog Post](https://huggingface.co/blog/gemma4) |
| - [Google USM Paper](https://arxiv.org/abs/2303.01037) — "Google USM: Scaling Automatic Speech Recognition Beyond 100 Languages" |