mazesmazes committed on
Commit
216def4
·
verified ·
1 Parent(s): e53b8de

Training in progress - step 1000

Browse files
Files changed (3) hide show
  1. config.json +1 -1
  2. model.safetensors +2 -2
  3. projectors.py +9 -18
config.json CHANGED
@@ -262,7 +262,7 @@
262
  "pad_token_id": 151643,
263
  "pipeline_tag": "automatic-speech-recognition",
264
  "pretrained_model_path": "mazesmazes/tiny-audio-next",
265
- "projector_hidden_dim": 2048,
266
  "projector_pool_stride": 4,
267
  "projector_type": "mlp",
268
  "qformer_hidden_size": null,
 
262
  "pad_token_id": 151643,
263
  "pipeline_tag": "automatic-speech-recognition",
264
  "pretrained_model_path": "mazesmazes/tiny-audio-next",
265
+ "projector_hidden_dim": 4096,
266
  "projector_pool_stride": 4,
267
  "projector_type": "mlp",
268
  "qformer_hidden_size": null,
model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:2877043b9b29a7fa8dd432ea544360d012f401fdc6f29ec07847d65a7d206413
3
- size 2433494416
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1a0b0a7491589ab19652cb41bb5aad27c3a78296bb3a8a66ea5beb3f99a3a81a
3
+ size 2483834256
projectors.py CHANGED
@@ -21,19 +21,15 @@ from transformers.models.llama.modeling_llama import LlamaRMSNorm
21
 
22
 
23
  class MLPAudioProjector(nn.Module):
24
- """2-layer MLP projector with frame-stacking downsampling (matches GLM-ASR)."""
25
-
26
- # RMSNorm init weight chosen to match Qwen3-0.6B's embed_tokens median
27
- # RMS (empirically 0.0292 across 151,936 tokens × dim=1024). With this
28
- # init the projector outputs enter the LM at the same per-position
29
- # residual-stream magnitude as text embed_tokens avoiding the
30
- # ~34× over-magnitude that LlamaRMSNorm's default weight=1.0 produces.
31
- # Adam's per-parameter normalization means this small init does NOT
32
- # starve projector gradient flow; the norm-before-GELU placement keeps
33
- # gradients healthy regardless of init magnitude. If you swap to a
34
- # different LM, re-measure with
35
- # `model.get_input_embeddings().weight.pow(2).mean().sqrt()` and update.
36
- _NORM_INIT = 0.029
37
 
38
  def __init__(self, config):
39
  """Initialize MLP projector.
@@ -53,14 +49,9 @@ class MLPAudioProjector(nn.Module):
53
  hidden_dim = getattr(config, "projector_hidden_dim", None) or llm_dim
54
  self.linear_1 = nn.Linear(in_dim, hidden_dim, bias=False)
55
  self.norm = LlamaRMSNorm(hidden_dim, eps=1e-6)
56
- self.norm.weight.data.fill_(self._NORM_INIT)
57
  self.act = nn.GELU()
58
  self.linear_2 = nn.Linear(hidden_dim, llm_dim, bias=False)
59
- # Output norm aligns the projector's RMS with the LM's embed_tokens
60
- # distribution. See _NORM_INIT comment above for the magnitude
61
- # derivation.
62
  self.norm_2 = LlamaRMSNorm(llm_dim, eps=1e-6)
63
- self.norm_2.weight.data.fill_(self._NORM_INIT)
64
 
65
  def get_output_length(self, input_length: int) -> int:
66
  """Calculate output sequence length given input length (matches GLM-ASR)."""
 
21
 
22
 
23
  class MLPAudioProjector(nn.Module):
24
+ """2-layer MLP projector with frame-stacking downsampling (matches GLM-ASR).
25
+
26
+ Both RMSNorms use LlamaRMSNorm's default weight=1.0 init. A prior version
27
+ initialized both to 0.029 (Qwen3-0.6B's embed_tokens RMS) to put projector
28
+ outputs at residual-stream scale on step 1. Empirically, after training, the
29
+ model drifted both norms back to ~1.0 (norm) and ~1.2 (norm_2) — the small
30
+ init wasted compute on a 35× scale-correction phase the optimizer would
31
+ have skipped from default init.
32
+ """
 
 
 
 
33
 
34
  def __init__(self, config):
35
  """Initialize MLP projector.
 
49
  hidden_dim = getattr(config, "projector_hidden_dim", None) or llm_dim
50
  self.linear_1 = nn.Linear(in_dim, hidden_dim, bias=False)
51
  self.norm = LlamaRMSNorm(hidden_dim, eps=1e-6)
 
52
  self.act = nn.GELU()
53
  self.linear_2 = nn.Linear(hidden_dim, llm_dim, bias=False)
 
 
 
54
  self.norm_2 = LlamaRMSNorm(llm_dim, eps=1e-6)
 
55
 
56
  def get_output_length(self, input_length: int) -> int:
57
  """Calculate output sequence length given input length (matches GLM-ASR)."""