mazesmazes committed on
Commit
cd1bcf8
·
verified ·
1 Parent(s): 33a7a22

Training in progress - step 1000

Browse files
Files changed (5) hide show
  1. asr_config.py +0 -7
  2. asr_modeling.py +0 -24
  3. config.json +0 -1
  4. model.safetensors +1 -1
  5. projectors.py +16 -3
asr_config.py CHANGED
@@ -51,12 +51,6 @@ class ASRConfig(transformers.PretrainedConfig):
51
  downsample_rate: int = 5, # Granite default
52
  projector_hidden_dim: Optional[int] = None,
53
  projector_type: str = "mlp", # "mlp", "mosa", "moe", "qformer"
54
- # Per-time-step Bernoulli zero-mask on encoder output before the
55
- # projector (training-only). 0.05–0.15 is the SpecAugment-equivalent
56
- # range for frozen-encoder setups; drops whole encoder frames so
57
- # the projector learns robustness to missing context. No magnitude
58
- # rescaling. 0.0 disables.
59
- audio_token_dropout: float = 0.0,
60
  # MoE-specific configuration
61
  num_experts: int = 4, # Number of experts in MoE projectors
62
  num_experts_per_tok: int = 2, # Top-k experts per token
@@ -123,7 +117,6 @@ class ASRConfig(transformers.PretrainedConfig):
123
  self.downsample_rate = downsample_rate
124
  self.projector_hidden_dim = projector_hidden_dim
125
  self.projector_type = projector_type
126
- self.audio_token_dropout = audio_token_dropout
127
  # MoE-specific configuration
128
  self.num_experts = num_experts
129
  self.num_experts_per_tok = num_experts_per_tok
 
51
  downsample_rate: int = 5, # Granite default
52
  projector_hidden_dim: Optional[int] = None,
53
  projector_type: str = "mlp", # "mlp", "mosa", "moe", "qformer"
 
 
 
 
 
 
54
  # MoE-specific configuration
55
  num_experts: int = 4, # Number of experts in MoE projectors
56
  num_experts_per_tok: int = 2, # Top-k experts per token
 
117
  self.downsample_rate = downsample_rate
118
  self.projector_hidden_dim = projector_hidden_dim
119
  self.projector_type = projector_type
 
120
  # MoE-specific configuration
121
  self.num_experts = num_experts
122
  self.num_experts_per_tok = num_experts_per_tok
asr_modeling.py CHANGED
@@ -449,35 +449,11 @@ class ASRModel(PreTrainedModel, GenerationMixin):
449
  encoder_out = self.audio_tower(input_features=audio_features)
450
  hidden_states = encoder_out.last_hidden_state
451
 
452
- hidden_states = self._maybe_drop_audio_tokens(hidden_states)
453
  audio_embeds = self.projector(hidden_states)
454
 
455
  token_counts = expected_token_counts.to(device=audio_embeds.device, dtype=torch.long)
456
  return _gather_audio_embeds(audio_embeds, token_counts)
457
 
458
- def _maybe_drop_audio_tokens(self, hidden_states: torch.Tensor) -> torch.Tensor:
459
- """Per-time-step Bernoulli zero-mask on encoder output (train-only).
460
-
461
- SpecAugment-equivalent for frozen-encoder setups: drops whole frames
462
- from the encoder output sequence so the projector learns robustness
463
- to missing context. Length-preserving (zeros, not deletions) so
464
- audio token counts in the prompt stay consistent. No magnitude
465
- rescaling — the projector should not learn to compensate.
466
- """
467
- p = float(getattr(self.config, "audio_token_dropout", 0.0))
468
- if not self.training or p <= 0.0:
469
- return hidden_states
470
- keep = 1.0 - p
471
- mask = torch.bernoulli(
472
- torch.full(
473
- hidden_states.shape[:-1],
474
- keep,
475
- device=hidden_states.device,
476
- dtype=hidden_states.dtype,
477
- )
478
- ).unsqueeze(-1)
479
- return hidden_states * mask
480
-
481
  def forward(
482
  self,
483
  input_ids: Optional[torch.Tensor] = None,
 
449
  encoder_out = self.audio_tower(input_features=audio_features)
450
  hidden_states = encoder_out.last_hidden_state
451
 
 
452
  audio_embeds = self.projector(hidden_states)
453
 
454
  token_counts = expected_token_counts.to(device=audio_embeds.device, dtype=torch.long)
455
  return _gather_audio_embeds(audio_embeds, token_counts)
456
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
457
  def forward(
458
  self,
459
  input_ids: Optional[torch.Tensor] = None,
config.json CHANGED
@@ -103,7 +103,6 @@
103
  },
104
  "audio_model_id": "zai-org/GLM-ASR-Nano-2512",
105
  "audio_sample_rate": 16000,
106
- "audio_token_dropout": 0.1,
107
  "auto_map": {
108
  "AutoConfig": "asr_config.ASRConfig",
109
  "AutoModel": "asr_modeling.ASRModel",
 
103
  },
104
  "audio_model_id": "zai-org/GLM-ASR-Nano-2512",
105
  "audio_sample_rate": 16000,
 
106
  "auto_map": {
107
  "AutoConfig": "asr_config.ASRConfig",
108
  "AutoModel": "asr_modeling.ASRModel",
model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:a78c24b6c8640d257dee04d0bf3c63c2020dcc24ed38b1112de8c0d77d930384
3
  size 2433494416
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:69814e5212595ce41dbec818e0a9fcbc59fb014dd15e4470c5ecc9acb33fff17
3
  size 2433494416
projectors.py CHANGED
@@ -23,6 +23,18 @@ from transformers.models.llama.modeling_llama import LlamaRMSNorm
23
  class MLPAudioProjector(nn.Module):
24
  """2-layer MLP projector with frame-stacking downsampling (matches GLM-ASR)."""
25
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  def __init__(self, config):
27
  """Initialize MLP projector.
28
 
@@ -41,13 +53,14 @@ class MLPAudioProjector(nn.Module):
41
  hidden_dim = getattr(config, "projector_hidden_dim", None) or llm_dim
42
  self.linear_1 = nn.Linear(in_dim, hidden_dim, bias=False)
43
  self.norm = LlamaRMSNorm(hidden_dim, eps=1e-6)
 
44
  self.act = nn.GELU()
45
  self.linear_2 = nn.Linear(hidden_dim, llm_dim, bias=False)
46
  # Output norm aligns the projector's RMS with the LM's embed_tokens
47
- # distribution. Without it, linear_2's Kaiming-uniform init produces
48
- # outputs ~30× quieter than embed rows, which saturates softmax at
49
- # audio positions and starves them of gradient.
50
  self.norm_2 = LlamaRMSNorm(llm_dim, eps=1e-6)
 
51
 
52
  def get_output_length(self, input_length: int) -> int:
53
  """Calculate output sequence length given input length (matches GLM-ASR)."""
 
23
  class MLPAudioProjector(nn.Module):
24
  """2-layer MLP projector with frame-stacking downsampling (matches GLM-ASR)."""
25
 
26
+ # RMSNorm init weight chosen to match Qwen3-0.6B's embed_tokens median
27
+ # RMS (empirically 0.0292 across 151,936 tokens × dim=1024). With this
28
+ # init the projector outputs enter the LM at the same per-position
29
+ # residual-stream magnitude as text embed_tokens — avoiding the
30
+ # ~34× over-magnitude that LlamaRMSNorm's default weight=1.0 produces.
31
+ # Adam's per-parameter normalization means this small init does NOT
32
+ # starve projector gradient flow; the norm-before-GELU placement keeps
33
+ # gradients healthy regardless of init magnitude. If you swap to a
34
+ # different LM, re-measure with
35
+ # `model.get_input_embeddings().weight.pow(2).mean(dim=-1).sqrt().median()` and update.
36
+ _NORM_INIT = 0.029
37
+
38
  def __init__(self, config):
39
  """Initialize MLP projector.
40
 
 
53
  hidden_dim = getattr(config, "projector_hidden_dim", None) or llm_dim
54
  self.linear_1 = nn.Linear(in_dim, hidden_dim, bias=False)
55
  self.norm = LlamaRMSNorm(hidden_dim, eps=1e-6)
56
+ self.norm.weight.data.fill_(self._NORM_INIT)
57
  self.act = nn.GELU()
58
  self.linear_2 = nn.Linear(hidden_dim, llm_dim, bias=False)
59
  # Output norm aligns the projector's RMS with the LM's embed_tokens
60
+ # distribution. See _NORM_INIT comment above for the magnitude
61
+ # derivation.
 
62
  self.norm_2 = LlamaRMSNorm(llm_dim, eps=1e-6)
63
+ self.norm_2.weight.data.fill_(self._NORM_INIT)
64
 
65
  def get_output_length(self, input_length: int) -> int:
66
  """Calculate output sequence length given input length (matches GLM-ASR)."""