rnagabh commited on
Commit
e3fb768
·
verified ·
1 Parent(s): 2eb1be3

Initial upload: Gemma 4 audio encoder (304.8M USM-style Conformer)

Browse files
Files changed (1) hide show
  1. README.md +17 -13
README.md CHANGED
@@ -106,22 +106,26 @@ waveform = np.random.randn(64000).astype(np.float32) # 4s @ 16kHz
106
  inputs = feature_extractor([waveform], sampling_rate=16000, return_tensors="pt")
107
 
108
  with torch.no_grad():
109
- output = audio_tower(inputs["input_features"].to(dtype=torch.bfloat16, device="cuda"))
110
 
111
- # Option 1: Text-projected embeddings (1536-dim) — maps into Gemma 4 text decoder space
 
 
112
  text_projected = output.last_hidden_state # (1, 100, 1536)
113
 
114
- # Option 2: Pure audio embeddings (1024-dim) — conformer output before projection
115
- # Recommended for downstream audio tasks (classification, verification, etc.)
116
- # Use a forward hook to capture the 1024-dim input to output_proj
117
- pre_proj_features = {}
118
- def hook_fn(module, input, output):
119
- pre_proj_features["hidden"] = input[0]
120
-
121
- handle = audio_tower.output_proj.register_forward_hook(hook_fn)
122
- _ = audio_tower(inputs["input_features"].to(dtype=torch.bfloat16, device="cuda"))
123
- handle.remove()
124
- audio_embeddings = pre_proj_features["hidden"] # (1, 100, 1024)
 
 
125
  ```
126
 
127
  > **Which to use?** For audio-only tasks (classification, speaker verification, deepfake detection),
 
106
  inputs = feature_extractor([waveform], sampling_rate=16000, return_tensors="pt")
107
 
108
  with torch.no_grad():
109
+ mel = inputs["input_features"].to(dtype=torch.bfloat16, device="cuda")
110
 
111
+ # === Option 1: Text-projected embeddings (1536-dim) ===
112
+ # Use this if you are feeding into an LLM or you need the full model output.
113
+ output = audio_tower(mel)
114
  text_projected = output.last_hidden_state # (1, 100, 1536)
115
 
116
+ # === Option 2: Pure audio embeddings (1024-dim) ===
117
+ # Captures the conformer output BEFORE the text projection layer.
118
+ # Recommended for downstream audio tasks (classification, verification, etc.)
119
+ # Note: this registers a hook and runs a separate forward pass.
120
+ pre_proj_features = {}
121
+ def hook_fn(module, input, output):
122
+ pre_proj_features["hidden"] = input[0]
123
+
124
+ handle = audio_tower.output_proj.register_forward_hook(hook_fn)
125
+ with torch.no_grad():
126
+ _ = audio_tower(mel)
127
+ handle.remove()
128
+ audio_embeddings = pre_proj_features["hidden"] # (1, 100, 1024)
129
  ```
130
 
131
  > **Which to use?** For audio-only tasks (classification, speaker verification, deepfake detection),