mazesmazes committed on
Commit
646108e
·
verified ·
1 Parent(s): 78e998a

Training in progress - step 1000

Browse files
Files changed (5) hide show
  1. asr_config.py +8 -13
  2. asr_modeling.py +0 -56
  3. config.json +1 -2
  4. model.safetensors +1 -1
  5. projectors.py +7 -0
asr_config.py CHANGED
@@ -50,6 +50,13 @@ class ASRConfig(transformers.PretrainedConfig):
50
  projector_pool_stride: int = 4,
51
  downsample_rate: int = 5, # Granite default
52
  projector_hidden_dim: Optional[int] = None,
 
 
 
 
 
 
 
53
  projector_type: str = "mlp", # "mlp", "mosa", "moe", "qformer"
54
  # MoE-specific configuration
55
  num_experts: int = 4, # Number of experts in MoE projectors
@@ -69,17 +76,6 @@ class ASRConfig(transformers.PretrainedConfig):
69
  lora_target_modules: Optional[list] = None, # Default: all linear layers
70
  freeze_projector: bool = False, # True for Stage 2 (LoRA-only training)
71
  freeze_language_model: bool = True, # False = full decoder fine-tuning
72
- # Encoder-output time masking — SpecAugment-style time masking applied
73
- # AFTER the frozen encoder, BEFORE the projector. The actual SOTA-
74
- # equivalent regularizer for frozen-encoder projector training: mel-
75
- # side SpecAugment (the NeMo / OWSM default at ~F=2,T=10 for trainable
76
- # encoders) would push a frozen Whisper encoder OOD, so we instead
77
- # mask the encoder's output features and let the projector learn to
78
- # reconstruct missing time positions. Disabled by default (0 / 0.0);
79
- # canonical setting is num=5 masks of max_width_ratio=0.04 (up to
80
- # ~20% of encoder output time positions masked per sample).
81
- encoder_output_time_mask_num: int = 0,
82
- encoder_output_time_mask_max_width_ratio: float = 0.0,
83
  do_sample: bool = False,
84
  temperature: Optional[float] = None,
85
  top_p: Optional[float] = None,
@@ -127,6 +123,7 @@ class ASRConfig(transformers.PretrainedConfig):
127
  self.projector_pool_stride = projector_pool_stride
128
  self.downsample_rate = downsample_rate
129
  self.projector_hidden_dim = projector_hidden_dim
 
130
  self.projector_type = projector_type
131
  # MoE-specific configuration
132
  self.num_experts = num_experts
@@ -154,8 +151,6 @@ class ASRConfig(transformers.PretrainedConfig):
154
  ]
155
  self.freeze_projector = freeze_projector
156
  self.freeze_language_model = freeze_language_model
157
- self.encoder_output_time_mask_num = encoder_output_time_mask_num
158
- self.encoder_output_time_mask_max_width_ratio = encoder_output_time_mask_max_width_ratio
159
 
160
  explicit_generation_args = {
161
  "num_beams": num_beams,
 
50
  projector_pool_stride: int = 4,
51
  downsample_rate: int = 5, # Granite default
52
  projector_hidden_dim: Optional[int] = None,
53
+ # Projector dropout — applied between activation and the second
54
+ # linear in MLPAudioProjector. Matches Granite-Speech 4.1's
55
+ # Q-Former dropout (hidden_dropout_prob=0.1) used in its frozen-
56
+ # encoder + LoRA-LLM training stage. Default 0.0 for backward
57
+ # compatibility with existing checkpoints; experiment configs
58
+ # opt in to 0.1.
59
+ projector_dropout: float = 0.0,
60
  projector_type: str = "mlp", # "mlp", "mosa", "moe", "qformer"
61
  # MoE-specific configuration
62
  num_experts: int = 4, # Number of experts in MoE projectors
 
76
  lora_target_modules: Optional[list] = None, # Default: all linear layers
77
  freeze_projector: bool = False, # True for Stage 2 (LoRA-only training)
78
  freeze_language_model: bool = True, # False = full decoder fine-tuning
 
 
 
 
 
 
 
 
 
 
 
79
  do_sample: bool = False,
80
  temperature: Optional[float] = None,
81
  top_p: Optional[float] = None,
 
123
  self.projector_pool_stride = projector_pool_stride
124
  self.downsample_rate = downsample_rate
125
  self.projector_hidden_dim = projector_hidden_dim
126
+ self.projector_dropout = projector_dropout
127
  self.projector_type = projector_type
128
  # MoE-specific configuration
129
  self.num_experts = num_experts
 
151
  ]
152
  self.freeze_projector = freeze_projector
153
  self.freeze_language_model = freeze_language_model
 
 
154
 
155
  explicit_generation_args = {
156
  "num_beams": num_beams,
asr_modeling.py CHANGED
@@ -44,41 +44,6 @@ def _gather_audio_embeds(audio_embeds: torch.Tensor, token_counts: torch.Tensor)
44
  return audio_embeds[mask]
45
 
46
 
47
- def _time_mask_encoder_output(
48
- hidden_states: torch.Tensor,
49
- num_masks: int,
50
- max_width_ratio: float,
51
- ) -> torch.Tensor:
52
- """SpecAugment-style time masking on encoder output features.
53
-
54
- Zero-fills ``num_masks`` random contiguous time spans per sample. Each
55
- span has width sampled uniformly from ``[0, max_width]`` where
56
- ``max_width = max(1, int(time_len * max_width_ratio))``. Returns the
57
- input unchanged when ``num_masks <= 0`` or ``max_width_ratio <= 0``.
58
-
59
- Args:
60
- hidden_states: ``(batch, time, dim)`` encoder output.
61
- num_masks: Number of time-mask spans applied per sample.
62
- max_width_ratio: Maximum mask width as a fraction of ``time``.
63
- """
64
- if num_masks <= 0 or max_width_ratio <= 0.0:
65
- return hidden_states
66
- batch, time_len, _ = hidden_states.shape
67
- max_width = max(1, int(time_len * max_width_ratio))
68
- device = hidden_states.device
69
-
70
- widths = torch.randint(0, max_width + 1, (batch, num_masks), device=device)
71
- max_starts = (time_len - widths).clamp(min=1)
72
- starts = (torch.rand(batch, num_masks, device=device) * max_starts).long()
73
-
74
- indices = torch.arange(time_len, device=device).view(1, 1, -1)
75
- starts_e = starts.unsqueeze(-1)
76
- ends_e = (starts + widths).unsqueeze(-1)
77
- in_any_mask = ((indices >= starts_e) & (indices < ends_e)).any(dim=1)
78
- keep = (~in_any_mask).to(dtype=hidden_states.dtype).unsqueeze(-1)
79
- return hidden_states * keep
80
-
81
-
82
  class ASRModel(PreTrainedModel, GenerationMixin):
83
  """Audio-to-text model combining an audio encoder, projector, and language model."""
84
 
@@ -466,26 +431,6 @@ class ASRModel(PreTrainedModel, GenerationMixin):
466
  self.config.encoder_conv_layers,
467
  )
468
 
469
- def _apply_encoder_output_time_masking(self, hidden_states: torch.Tensor) -> torch.Tensor:
470
- """SpecAugment-style time masking on encoder OUTPUT features.
471
-
472
- For frozen-encoder projector training, mel-side SpecAugment would
473
- push the encoder OOD (it was never trained with masked mels), so
474
- we mask the encoder output instead. The projector learns to be
475
- robust to missing encoder-output time positions. Disabled outside
476
- training; otherwise delegates to ``_time_mask_encoder_output`` with
477
- the config knobs.
478
- """
479
- if not self.training:
480
- return hidden_states
481
- return _time_mask_encoder_output(
482
- hidden_states,
483
- num_masks=int(getattr(self.config, "encoder_output_time_mask_num", 0)),
484
- max_width_ratio=float(
485
- getattr(self.config, "encoder_output_time_mask_max_width_ratio", 0.0)
486
- ),
487
- )
488
-
489
  def _encode_audio(
490
  self,
491
  audio_features: torch.Tensor,
@@ -504,7 +449,6 @@ class ASRModel(PreTrainedModel, GenerationMixin):
504
  encoder_out = self.audio_tower(input_features=audio_features)
505
  hidden_states = encoder_out.last_hidden_state
506
 
507
- hidden_states = self._apply_encoder_output_time_masking(hidden_states)
508
  audio_embeds = self.projector(hidden_states)
509
 
510
  token_counts = expected_token_counts.to(device=audio_embeds.device, dtype=torch.long)
 
44
  return audio_embeds[mask]
45
 
46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  class ASRModel(PreTrainedModel, GenerationMixin):
48
  """Audio-to-text model combining an audio encoder, projector, and language model."""
49
 
 
431
  self.config.encoder_conv_layers,
432
  )
433
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
434
  def _encode_audio(
435
  self,
436
  audio_features: torch.Tensor,
 
449
  encoder_out = self.audio_tower(input_features=audio_features)
450
  hidden_states = encoder_out.last_hidden_state
451
 
 
452
  audio_embeds = self.projector(hidden_states)
453
 
454
  token_counts = expected_token_counts.to(device=audio_embeds.device, dtype=torch.long)
config.json CHANGED
@@ -234,8 +234,6 @@
234
  ]
235
  ],
236
  "encoder_dim": 1280,
237
- "encoder_output_time_mask_max_width_ratio": 0.04,
238
- "encoder_output_time_mask_num": 5,
239
  "eos_token_id": 151645,
240
  "freeze_language_model": false,
241
  "freeze_projector": false,
@@ -264,6 +262,7 @@
264
  "pad_token_id": 151643,
265
  "pipeline_tag": "automatic-speech-recognition",
266
  "pretrained_model_path": "mazesmazes/tiny-audio-next",
 
267
  "projector_hidden_dim": 2048,
268
  "projector_pool_stride": 4,
269
  "projector_type": "mlp",
 
234
  ]
235
  ],
236
  "encoder_dim": 1280,
 
 
237
  "eos_token_id": 151645,
238
  "freeze_language_model": false,
239
  "freeze_projector": false,
 
262
  "pad_token_id": 151643,
263
  "pipeline_tag": "automatic-speech-recognition",
264
  "pretrained_model_path": "mazesmazes/tiny-audio-next",
265
+ "projector_dropout": 0.1,
266
  "projector_hidden_dim": 2048,
267
  "projector_pool_stride": 4,
268
  "projector_type": "mlp",
model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:f20696bf71ed993d8320bb51ae89bd1f3dee0392c2f69fd0578b7095e4ee9d2b
3
  size 2433494416
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:97b1eda3c22a7e702033952c30ab1de35166bb22100f243c283d980978e1a8bd
3
  size 2433494416
projectors.py CHANGED
@@ -55,6 +55,12 @@ class MLPAudioProjector(nn.Module):
55
  self.norm = LlamaRMSNorm(hidden_dim, eps=1e-6)
56
  self.norm.weight.data.fill_(self._NORM_INIT)
57
  self.act = nn.GELU()
 
 
 
 
 
 
58
  self.linear_2 = nn.Linear(hidden_dim, llm_dim, bias=False)
59
  # Output norm aligns the projector's RMS with the LM's embed_tokens
60
  # distribution. See _NORM_INIT comment above for the magnitude
@@ -80,6 +86,7 @@ class MLPAudioProjector(nn.Module):
80
  x = self.linear_1(x)
81
  x = self.norm(x)
82
  x = self.act(x)
 
83
  x = self.linear_2(x)
84
  return self.norm_2(x)
85
 
 
55
  self.norm = LlamaRMSNorm(hidden_dim, eps=1e-6)
56
  self.norm.weight.data.fill_(self._NORM_INIT)
57
  self.act = nn.GELU()
58
+ # Dropout matches Granite-Speech 4.1's Q-Former hidden_dropout_prob=0.1
59
+ # in its frozen-encoder modality-alignment stage — the closest
60
+ # published precedent for our regime. Default 0.0 in config means
61
+ # nn.Dropout(0.0) is a no-op for existing experiments.
62
+ projector_dropout = float(getattr(config, "projector_dropout", 0.0))
63
+ self.dropout = nn.Dropout(projector_dropout)
64
  self.linear_2 = nn.Linear(hidden_dim, llm_dim, bias=False)
65
  # Output norm aligns the projector's RMS with the LM's embed_tokens
66
  # distribution. See _NORM_INIT comment above for the magnitude
 
86
  x = self.linear_1(x)
87
  x = self.norm(x)
88
  x = self.act(x)
89
+ x = self.dropout(x)
90
  x = self.linear_2(x)
91
  return self.norm_2(x)
92