---
language:
- en
- multilingual
license: apache-2.0
library_name: transformers
tags:
- feature-extraction
- audio
- speech
- conformer
- gemma4
- usm
- google
- safetensors
pipeline_tag: feature-extraction
base_model: google/gemma-4-E2B-it
model-index:
- name: gemma4-audio-encoder
  results:
  - task:
      type: audio-classification
      name: Speech Commands (35-class)
    dataset:
      name: Google Speech Commands v0.02
      type: google/speech_commands
      split: validation
    metrics:
    - type: accuracy
      value: 72.0
      name: Linear Probe Accuracy
---
# Gemma 4 Audio Encoder (USM-style Conformer)
Standalone extraction of the audio encoder from Google's [Gemma 4](https://huggingface.co/google/gemma-4-E2B-it) multimodal model family. This is a 304.8M-parameter USM-style Conformer that converts audio waveforms (via a 128-bin mel spectrogram) into embeddings.
**License:** Apache 2.0 (inherited from Gemma 4)
## Architecture
| Property | Value |
|---|---|
| Total parameters | 304.8M |
| Architecture | USM-style Conformer (Macaron-net) |
| Hidden dimension | 1024 (pure audio representation) |
| Output dimension | 1536 (text-projected via `output_proj`) |
| Conformer layers | 12 |
| Attention heads | 8 (128 dim per head) |
| FFW intermediate | 4096 (4× expansion) |
| Depthwise conv kernel | 5 |
| Subsampling conv channels | [128, 32] |
| Input | 128-bin mel spectrogram @ 16kHz |
| Conformer activation | SiLU |
| Subsampling activation | ReLU |
| Conformer normalization | RMSNorm (eps=1e-6) |
| Subsampling normalization | LayerNorm |
| Residual weight | 0.5 (Macaron half-step) |
| Attention type | Chunked causal (chunk_size=12, left_context=13, right_context=0) |
| Clipped linears | Yes (quantization-ready input_min/max, output_min/max per layer) |
| Temporal downsampling | 4× (two stride-2 Conv2d layers) |
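As a rough sanity check on the table, the weight matrices it lists can be tallied directly. This sketch counts only the matrices named above plus `output_proj`, assumes they carry no biases, and deliberately ignores norms, the subsampling convolutions, relative-position parameters, and clipping bounds, so it should land below 304.8M rather than match it.

```python
# Rough weight count from the architecture table (matrices only, biases assumed absent).
D, FF, LAYERS, KERNEL, OUT = 1024, 4096, 12, 5, 1536


def conformer_block_weights(d=D, ff=FF, kernel=KERNEL):
    ffw = 2 * (d * ff + ff * d)              # FFW1 + FFW2, each with up and down projections
    attn = 4 * d * d                          # Q, K, V and post projection
    conv = d * (2 * d) + d * kernel + d * d  # gated in-projection, depthwise conv, out-projection
    return ffw + attn + conv


blocks = LAYERS * conformer_block_weights()
output_proj = D * OUT
total = blocks + output_proj
print(f"{blocks:,} in conformer blocks + {output_proj:,} in output_proj = {total:,}")
# 289,468,416 in conformer blocks + 1,572,864 in output_proj = 291,041,280
```

That accounts for roughly 291.0M of the 304.8M; the remaining ~13.8M sits in the components the sketch leaves out, which this card does not break down.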
### Conformer Block Structure
Each of the 12 conformer blocks follows the Macaron-net pattern:
```
Input
→ FFW1: pre_layer_norm → Linear(1024→4096) → SiLU → Linear(4096→1024) → post_layer_norm
→ + 0.5 × residual
→ Self-Attention: norm_pre_attn → Q/K/V proj (1024→1024) → relative position → post proj → norm_post_attn
→ + residual
→ LightConv1d: pre_layer_norm → Linear(1024→2048, gated) → DepthwiseConv1d(k=5) → conv_norm → Linear(1024→1024)
→ + residual
→ FFW2: pre_layer_norm → Linear(1024→4096) → SiLU → Linear(4096→1024) → post_layer_norm
→ + 0.5 × residual
→ norm_out
```
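The data flow above can be sketched in PyTorch. This is a simplified illustration, not the Transformers implementation: it swaps the chunked, relative-position attention for `nn.MultiheadAttention` with a plain causal mask, omits the clipped linears, and assumes a GLU for the gated projection, left-only (causal) padding for the depthwise convolution, and bias-free linears. `MacaronBlock` and the other class names are hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight


class FeedForward(nn.Module):
    """pre_layer_norm -> Linear(d -> d_ff) -> SiLU -> Linear(d_ff -> d) -> post_layer_norm"""

    def __init__(self, d, d_ff):
        super().__init__()
        self.pre_layer_norm = RMSNorm(d)
        self.up = nn.Linear(d, d_ff, bias=False)
        self.down = nn.Linear(d_ff, d, bias=False)
        self.post_layer_norm = RMSNorm(d)

    def forward(self, x):
        return self.post_layer_norm(self.down(F.silu(self.up(self.pre_layer_norm(x)))))


class LightConv1d(nn.Module):
    """pre_layer_norm -> gated Linear(d -> 2d) -> depthwise Conv1d -> conv_norm -> Linear(d -> d)"""

    def __init__(self, d, kernel_size):
        super().__init__()
        self.kernel_size = kernel_size
        self.pre_layer_norm = RMSNorm(d)
        self.gate_proj = nn.Linear(d, 2 * d, bias=False)
        self.depthwise = nn.Conv1d(d, d, kernel_size, groups=d, bias=False)
        self.conv_norm = RMSNorm(d)
        self.out_proj = nn.Linear(d, d, bias=False)

    def forward(self, x):  # x: (batch, time, d)
        h = F.glu(self.gate_proj(self.pre_layer_norm(x)), dim=-1)  # GLU halves 2d back to d
        h = F.pad(h.transpose(1, 2), (self.kernel_size - 1, 0))  # left-pad time (causal, assumed)
        h = self.depthwise(h).transpose(1, 2)  # back to (batch, time, d)
        return self.out_proj(self.conv_norm(h))


class MacaronBlock(nn.Module):
    def __init__(self, d=1024, heads=8, d_ff=4096, kernel_size=5):
        super().__init__()
        self.ffw1 = FeedForward(d, d_ff)
        self.norm_pre_attn = RMSNorm(d)
        # Stand-in only: the real encoder uses chunked attention with relative positions.
        self.attn = nn.MultiheadAttention(d, heads, batch_first=True)
        self.norm_post_attn = RMSNorm(d)
        self.conv = LightConv1d(d, kernel_size)
        self.ffw2 = FeedForward(d, d_ff)
        self.norm_out = RMSNorm(d)

    def forward(self, x):  # x: (batch, time, d)
        x = x + 0.5 * self.ffw1(x)  # first Macaron half-step
        t = x.size(1)
        future = torch.triu(torch.ones(t, t, dtype=torch.bool, device=x.device), diagonal=1)
        h = self.norm_pre_attn(x)
        h, _ = self.attn(h, h, h, attn_mask=future, need_weights=False)  # True = masked out
        x = x + self.norm_post_attn(h)
        x = x + self.conv(x)
        x = x + 0.5 * self.ffw2(x)  # second Macaron half-step
        return self.norm_out(x)
```

Stacking 12 of these (with the real attention) gives the encoder body; shapes stay `(batch, time, 1024)` throughout.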
### Input/Output Shapes
- **Input:** `(batch, time_frames, 128)` — 128-bin mel features, time-first
- **Output:** `(batch, time_frames/4, 1536)` — 4× temporal downsampling, projected to 1536
- For 4 seconds of 16kHz audio: input ~(1, 399, 128) → output ~(1, 100, 1536)
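The ~399 → ~100 figures follow from simple frame arithmetic. The sketch assumes a 20 ms analysis window (320 samples) and a 10 ms hop (160 samples) with no padding, plus kernel-3, padding-1, stride-2 convolutions. These values are assumptions chosen because they reproduce the numbers above; the actual extractor settings live in `preprocessor_config.json`.

```python
def mel_frames(num_samples, frame_length=320, hop_length=160):
    # Count of full analysis windows without padding (assumed window/hop).
    return (num_samples - frame_length) // hop_length + 1


def conv_out(length, kernel=3, stride=2, padding=1):
    # Standard convolution output-length formula (kernel/padding assumed).
    return (length + 2 * padding - kernel) // stride + 1


def encoder_frames(num_samples):
    t = mel_frames(num_samples)
    for _ in range(2):  # two stride-2 Conv2d layers -> 4x temporal downsampling
        t = conv_out(t)
    return t


print(mel_frames(64000), encoder_frames(64000))  # 399 100
```

Under these assumptions each output frame covers 40 ms, about 25 frames per second (`encoder_frames(16000)` gives 25). Applying the same `conv_out` twice to the 128 mel bins gives 32, and 32 bins × 32 output channels = 1024, which would match the hidden dimension if the subsampler flattens frequency and channels; that reading is an inference, not something verified here.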
## Usage
```python
import torch
import numpy as np
from transformers import Gemma4AudioModel, Gemma4AudioFeatureExtractor

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the audio encoder directly from this repo
audio_tower = Gemma4AudioModel.from_pretrained(
    "rnagabh/gemma4-audio-encoder",
    torch_dtype=torch.bfloat16,
)
audio_tower.to(device)
audio_tower.eval()

# Load the feature extractor (saved in this repo)
feature_extractor = Gemma4AudioFeatureExtractor.from_pretrained("rnagabh/gemma4-audio-encoder")

# Compute mel features from a waveform
waveform = np.random.randn(64000).astype(np.float32)  # 4 s @ 16 kHz
inputs = feature_extractor([waveform], sampling_rate=16000, return_tensors="pt")
mel = inputs["input_features"].to(dtype=torch.bfloat16, device=device)

# === Option 1: Text-projected embeddings (1536-dim) ===
# Use this when feeding an LLM or when you need the full model output.
with torch.no_grad():
    output = audio_tower(mel)
text_projected = output.last_hidden_state  # (1, 100, 1536)

# === Option 2: Pure audio embeddings (1024-dim) ===
# Captures the conformer output BEFORE the text projection layer.
# Recommended for downstream audio tasks (classification, verification, etc.).
# Note: this registers a forward hook and runs a separate forward pass.
pre_proj_features = {}

def hook_fn(module, args, output):
    # args is the tuple of positional inputs to output_proj
    pre_proj_features["hidden"] = args[0].detach()

handle = audio_tower.output_proj.register_forward_hook(hook_fn)
try:
    with torch.no_grad():
        _ = audio_tower(mel)
finally:
    handle.remove()  # always remove the hook, even if the forward pass fails
audio_embeddings = pre_proj_features["hidden"]  # (1, 100, 1024)
```
> **Which to use?** For audio-only tasks (classification, speaker verification, deepfake detection),
> use the 1024-dim pre-projection embeddings; the benchmarks on this card were computed on them. They are
> taken before the text projection, so they may retain acoustic detail that the projection, trained to serve
> the LLM, does not preserve. The 1536-dim output is designed for feeding into an LLM decoder.
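If both embeddings are needed, the hook can stay registered during a single forward pass instead of running the model twice. This is a sketch, assuming only that the model exposes an `output_proj` submodule whose first positional input is the 1024-dim tensor; `run_with_pre_projection` is a hypothetical helper.

```python
import torch
import torch.nn as nn


def run_with_pre_projection(model: nn.Module, *args, **kwargs):
    """Run one forward pass and also return the input that reached model.output_proj."""
    captured = {}

    def hook(module, hook_args, output):
        captured["pre_proj"] = hook_args[0].detach()

    handle = model.output_proj.register_forward_hook(hook)
    try:
        with torch.no_grad():
            output = model(*args, **kwargs)
    finally:
        handle.remove()  # detach the hook even if the forward pass raises
    return output, captured["pre_proj"]


# e.g. output, audio_embeddings = run_with_pre_projection(audio_tower, mel)
#      text_projected = output.last_hidden_state
```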
## Critical: The AutoModel Loading Gotcha
⚠️ **`AutoModel.from_pretrained("google/gemma-4-E2B-it")` silently fails to load audio tower weights.**
All audio tower parameters are randomly initialized (std ≈ 0.02). The model runs without errors and produces outputs of the correct shape, but those outputs are meaningless.
**Root cause:** The checkpoint stores keys with a `model.` prefix (e.g., `model.audio_tower.layers.0...`), while `AutoModel` builds a module tree that expects keys without it. As a result, every checkpoint key is reported as UNEXPECTED and every model key as MISSING. Transformers does not raise an error on missing or unexpected keys, so loading succeeds and every audio tower parameter falls back to fresh initialization; the only signal is a load-time warning that is easy to miss.
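The mismatch can be reproduced with plain key sets. The key names below are illustrative rather than copied from the checkpoint, and `diagnose_prefix` is a hypothetical helper; it shows why every key ends up both unexpected and missing, and that stripping the `model.` prefix would line them up.

```python
def diagnose_prefix(checkpoint_keys, expected_keys, prefix="model."):
    ckpt, expected = set(checkpoint_keys), set(expected_keys)
    unexpected = sorted(ckpt - expected)  # present in the file, unknown to the module tree
    missing = sorted(expected - ckpt)     # wanted by the module tree, absent from the file
    stripped = {k[len(prefix):] for k in ckpt if k.startswith(prefix)}
    recoverable = sorted(expected & stripped)  # would match once the prefix is removed
    return unexpected, missing, recoverable


checkpoint = ["model.audio_tower.layers.0.attn.q_proj.weight",
              "model.audio_tower.output_proj.weight"]
expected = ["audio_tower.layers.0.attn.q_proj.weight",
            "audio_tower.output_proj.weight"]
unexpected, missing, recoverable = diagnose_prefix(checkpoint, expected)
print(len(unexpected), len(missing), len(recoverable))  # 2 2 2
```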
**Fix:** Use `AutoModelForMultimodalLM` instead:
```python
from transformers import AutoModel, AutoModelForMultimodalLM

# ❌ WRONG: audio tower weights are randomly initialized
model = AutoModel.from_pretrained("google/gemma-4-E2B-it")
audio_tower = model.audio_tower  # RANDOM WEIGHTS

# ✅ CORRECT: audio tower weights load properly
model = AutoModelForMultimodalLM.from_pretrained("google/gemma-4-E2B-it")
audio_tower = model.model.audio_tower  # TRAINED WEIGHTS
```
**How to verify:**
```python
w = audio_tower.output_proj.weight.float()
print(f"std={w.std().item():.6f}")
# ✅ Trained: std ≈ 0.031250
# ❌ Random: std ≈ 0.019884
```
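The same check can be extended to scan every weight matrix. This is a heuristic sketch, assuming fresh initialization draws from a normal distribution with std 0.02 (as the numbers above suggest). A trained matrix could coincidentally land near 0.02, so treat hits as a prompt to investigate, not as proof. `find_suspicious_weights` is a hypothetical helper.

```python
import torch
import torch.nn as nn


def find_suspicious_weights(module: nn.Module, init_std=0.02, rel_tol=0.05):
    """Names of parameters with 2+ dims whose std lies within rel_tol of init_std."""
    suspicious = []
    for name, param in module.named_parameters():
        if param.dim() < 2:  # skip norm scales and biases (ones/zeros at init)
            continue
        std = param.detach().float().std().item()
        if abs(std - init_std) / init_std < rel_tol:
            suspicious.append(name)
    return suspicious


# e.g. print(find_suspicious_weights(audio_tower))
```

After a correct load one would expect few or no hits; a list covering most of the tower points to the loading problem described above.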
## E2B and E4B Share Identical Audio Weights
The audio encoder weights are **byte-for-byte identical** between Gemma 4 E2B and E4B. This was verified empirically — all 751 parameter tensors match exactly.
Gemma 4's E2B is a MatFormer sub-model nested inside E4B. The MatFormer architecture only affects the text decoder's feed-forward dimensions. The audio tower sits outside the MatFormer nesting and is a shared module.
**Implication:** There is no reason to prefer E4B over E2B for audio encoder extraction. E2B is a smaller download (~10GB vs ~16GB).
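A comparison like the one above can be done by walking two state dicts and checking exact value equality with `torch.equal`. This is a sketch, assuming both audio towers have already been loaded (for example via `AutoModelForMultimodalLM`); `compare_state_dicts` and the `e2b_tower`/`e4b_tower` names are hypothetical.

```python
import torch


def compare_state_dicts(a, b):
    """Return (keys only in a, keys only in b, shared keys whose tensors differ)."""
    only_a = sorted(set(a) - set(b))
    only_b = sorted(set(b) - set(a))
    differing = sorted(
        k for k in set(a) & set(b)
        if a[k].shape != b[k].shape
        or a[k].dtype != b[k].dtype
        or not torch.equal(a[k], b[k])
    )
    return only_a, only_b, differing


# e.g. compare_state_dicts(e2b_tower.state_dict(), e4b_tower.state_dict())
# Identical towers give ([], [], []).
```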
## Files in This Repo
| File | Description | Size |
|---|---|---|
| `config.json` | Audio tower config (Gemma4AudioConfig) | <1 KB |
| `model.safetensors` | Audio tower weights (304.8M params, BF16) | 609.7 MB |
| `preprocessor_config.json` | Mel spectrogram feature extractor config | <1 KB |
| `embed_audio.safetensors` | Audio→text embedding projection (1536→1536) | 4.7 MB |
## Limitations
- **End-to-end trained for LLM decoding:** The encoder was trained to produce features for Gemma 4's text decoder, not as a general-purpose audio encoder. For standalone feature extraction, the 1024-dim pre-projection output (before `output_proj`) may be more useful than the 1536-dim post-projection output.
- **Causal chunked attention:** The encoder uses right_context=0, so it cannot look ahead. In offline (non-streaming) settings this may put it at a disadvantage compared to bidirectional encoders, which can use future context.
- **Multi-layer fusion doesn't help:** Unlike wav2vec2/w2v-BERT, where combining multiple hidden layers typically improves downstream performance, fusing this encoder's layers did not help; use the final-layer output. A plausible explanation is the Macaron half-step residuals combined with causal attention.
- **Subsampling frontend uses ReLU + LayerNorm** (not SiLU + GroupNorm as in some USM descriptions).
- **Not a speaker encoder:** While embeddings show some speaker separation (cosine similarity gap of ~0.03), this model was not trained for speaker verification, and dedicated speaker models should be expected to outperform it substantially on speaker tasks.
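The right_context=0 limitation can be pictured as a boolean attention mask. The sketch below is a simplified reading of the config values, assuming each frame may attend to itself plus the `left_context - 1` preceding frames and to `right_context` future frames; it ignores how `chunk_size` blocks the computation and has not been checked against the Transformers implementation.

```python
import numpy as np


def local_attention_mask(num_frames, left_context=13, right_context=0):
    """mask[i, j] is True when query frame i may attend to key frame j."""
    i = np.arange(num_frames)[:, None]
    j = np.arange(num_frames)[None, :]
    offset = j - i  # negative = past, positive = future
    return (offset >= -(left_context - 1)) & (offset <= right_context)


mask = local_attention_mask(100)
print(mask[50, 38], mask[50, 37], mask[50, 51])  # True False False
```

Under this reading, 12 stacked layers let information reach at most 12 × 12 = 144 frames (about 5.8 s at 40 ms per output frame) into the past through attention alone, and never into the future.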
## Benchmark Results (frozen 1024-dim embeddings, linear probe)
### Speech Commands Classification (35 classes)
| Metric | Value |
|---|---|
| Linear probe accuracy | **72.0%** |
| Random baseline | 2.9% |
| Improvement over chance | **25×** |
| Dataset | Google Speech Commands v0.02 (validation) |
| Probe | Logistic regression on L2-normalized mean-pooled embeddings |
The encoder captures word-level phonetic content: per-class F1 is high for acoustically distinct words (seven: 0.93; house, stop, eight: 0.89) and lower for similar-sounding pairs such as three/tree.
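The probe in the table can be sketched as follows, assuming one `(time, 1024)` array per clip (the pre-projection embeddings from Option 2 in the usage example) and scikit-learn's `LogisticRegression` with default regularization. The exact hyperparameters and training split behind the 72.0% figure are not stated here, and the helper names are hypothetical.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression


def pool_and_normalize(frame_embeddings):
    """Mean-pool (time, dim) frames to (dim,), then L2-normalize."""
    pooled = np.asarray(frame_embeddings, dtype=np.float64).mean(axis=0)
    return pooled / (np.linalg.norm(pooled) + 1e-12)


def fit_linear_probe(train_clips, train_labels):
    X = np.stack([pool_and_normalize(c) for c in train_clips])
    probe = LogisticRegression(max_iter=1000)
    probe.fit(X, train_labels)
    return probe


def probe_accuracy(probe, clips, labels):
    X = np.stack([pool_and_normalize(c) for c in clips])
    return float((probe.predict(X) == np.asarray(labels)).mean())
```

Keeping the encoder frozen and the classifier linear means the accuracy reflects how linearly separable the pooled embeddings already are.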
### Speaker Similarity (LibriSpeech test-clean)
| Metric | Value |
|---|---|
| Same-speaker cosine similarity | 0.656 ± 0.147 |
| Different-speaker cosine similarity | 0.622 ± 0.132 |
| Separation gap | 0.034 |
Modest speaker separation, as expected for an ASR-oriented encoder rather than a speaker verification model.
![Speaker Similarity Distribution](gemma4_speaker_similarity.png)
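The same/different-speaker statistics can be computed from pooled embeddings as below. This is a sketch, assuming one mean-pooled vector per utterance and averaging over all unordered pairs (population standard deviation); the pairing scheme behind the numbers above is not stated here, and `speaker_similarity_stats` is a hypothetical helper.

```python
import numpy as np
from itertools import combinations


def speaker_similarity_stats(embeddings, speaker_ids):
    """Cosine similarity over all unordered pairs, split by same vs different speaker."""
    X = np.asarray(embeddings, dtype=np.float64)
    X = X / np.linalg.norm(X, axis=1, keepdims=True)  # L2-normalize each utterance vector
    same, diff = [], []
    for a, b in combinations(range(len(X)), 2):
        cos = float(X[a] @ X[b])
        (same if speaker_ids[a] == speaker_ids[b] else diff).append(cos)
    same, diff = np.array(same), np.array(diff)
    return {
        "same_mean": same.mean(), "same_std": same.std(),
        "diff_mean": diff.mean(), "diff_std": diff.std(),
        "gap": same.mean() - diff.mean(),
    }
```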
### t-SNE Speaker Clustering
![t-SNE Speaker Embeddings](gemma4_tsne_speakers.png)
Embeddings show partial speaker clustering: the encoder picks up speaker characteristics as a byproduct of ASR training but is not optimized for speaker discrimination.
## Extraction Details
- Extracted from `google/gemma-4-E2B-it` using `AutoModelForMultimodalLM`
- Weights saved in BF16 as safetensors
- Forward pass verified: extracted model produces outputs with **0.0 max absolute difference** from the original
- All architecture specs independently verified against the live model
## References
- [Gemma 4 on HuggingFace](https://huggingface.co/google/gemma-4-E2B-it)
- [Gemma 4 Blog Post](https://huggingface.co/blog/gemma4)
- [Google USM Paper](https://arxiv.org/abs/2303.01037): "Google USM: Scaling Automatic Speech Recognition Beyond 100 Languages"