Initial upload: Gemma 4 audio encoder (304.8M USM-style Conformer)
Browse files
README.md
CHANGED
|
@@ -106,22 +106,26 @@ waveform = np.random.randn(64000).astype(np.float32) # 4s @ 16kHz
|
|
| 106 |
inputs = feature_extractor([waveform], sampling_rate=16000, return_tensors="pt")
|
| 107 |
|
| 108 |
with torch.no_grad():
|
| 109 |
-
|
| 110 |
|
| 111 |
-
# Option 1: Text-projected embeddings (1536-dim)
|
|
|
|
|
|
|
| 112 |
text_projected = output.last_hidden_state # (1, 100, 1536)
|
| 113 |
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
|
|
|
|
|
|
| 125 |
```
|
| 126 |
|
| 127 |
> **Which to use?** For audio-only tasks (classification, speaker verification, deepfake detection),
|
|
|
|
| 106 |
inputs = feature_extractor([waveform], sampling_rate=16000, return_tensors="pt")
|
| 107 |
|
| 108 |
with torch.no_grad():
|
| 109 |
+
mel = inputs["input_features"].to(dtype=torch.bfloat16, device="cuda")
|
| 110 |
|
| 111 |
+
# === Option 1: Text-projected embeddings (1536-dim) ===
|
| 112 |
+
# Use this if feeding into an LLM or need the full model output.
|
| 113 |
+
output = audio_tower(mel)
|
| 114 |
text_projected = output.last_hidden_state # (1, 100, 1536)
|
| 115 |
|
| 116 |
+
# === Option 2: Pure audio embeddings (1024-dim) ===
|
| 117 |
+
# Captures the conformer output BEFORE the text projection layer.
|
| 118 |
+
# Recommended for downstream audio tasks (classification, verification, etc.)
|
| 119 |
+
# Note: this registers a hook and runs a separate forward pass.
|
| 120 |
+
pre_proj_features = {}
|
| 121 |
+
def hook_fn(module, input, output):
|
| 122 |
+
pre_proj_features["hidden"] = input[0]
|
| 123 |
+
|
| 124 |
+
handle = audio_tower.output_proj.register_forward_hook(hook_fn)
|
| 125 |
+
with torch.no_grad():
|
| 126 |
+
_ = audio_tower(mel)
|
| 127 |
+
handle.remove()
|
| 128 |
+
audio_embeddings = pre_proj_features["hidden"] # (1, 100, 1024)
|
| 129 |
```
|
| 130 |
|
| 131 |
> **Which to use?** For audio-only tasks (classification, speaker verification, deepfake detection),
|