mazesmazes commited on
Commit
ea0288b
·
verified ·
1 Parent(s): a77f479

Training in progress - step 1000

Browse files
Files changed (4) hide show
  1. asr_config.py +13 -0
  2. asr_modeling.py +56 -0
  3. config.json +2 -0
  4. model.safetensors +1 -1
asr_config.py CHANGED
@@ -69,6 +69,17 @@ class ASRConfig(transformers.PretrainedConfig):
69
  lora_target_modules: Optional[list] = None, # Default: all linear layers
70
  freeze_projector: bool = False, # True for Stage 2 (LoRA-only training)
71
  freeze_language_model: bool = True, # False = full decoder fine-tuning
 
 
 
 
 
 
 
 
 
 
 
72
  do_sample: bool = False,
73
  temperature: Optional[float] = None,
74
  top_p: Optional[float] = None,
@@ -143,6 +154,8 @@ class ASRConfig(transformers.PretrainedConfig):
143
  ]
144
  self.freeze_projector = freeze_projector
145
  self.freeze_language_model = freeze_language_model
 
 
146
 
147
  explicit_generation_args = {
148
  "num_beams": num_beams,
 
69
  lora_target_modules: Optional[list] = None, # Default: all linear layers
70
  freeze_projector: bool = False, # True for Stage 2 (LoRA-only training)
71
  freeze_language_model: bool = True, # False = full decoder fine-tuning
72
+ # Encoder-output time masking — SpecAugment-style time masking applied
73
+ # AFTER the frozen encoder, BEFORE the projector. The actual SOTA-
74
+ # equivalent regularizer for frozen-encoder projector training: mel-
75
+ # side SpecAugment (the NeMo / OWSM default at ~F=2,T=10 for trainable
76
+ # encoders) would push a frozen Whisper encoder OOD, so we instead
77
+ # mask the encoder's output features and let the projector learn to
78
+ # reconstruct missing time positions. Disabled by default (0 / 0.0);
79
+ # canonical setting is num=5 masks of max_width_ratio=0.04 (up to
80
+ # ~20% of encoder output time positions masked per sample).
81
+ encoder_output_time_mask_num: int = 0,
82
+ encoder_output_time_mask_max_width_ratio: float = 0.0,
83
  do_sample: bool = False,
84
  temperature: Optional[float] = None,
85
  top_p: Optional[float] = None,
 
154
  ]
155
  self.freeze_projector = freeze_projector
156
  self.freeze_language_model = freeze_language_model
157
+ self.encoder_output_time_mask_num = encoder_output_time_mask_num
158
+ self.encoder_output_time_mask_max_width_ratio = encoder_output_time_mask_max_width_ratio
159
 
160
  explicit_generation_args = {
161
  "num_beams": num_beams,
asr_modeling.py CHANGED
@@ -44,6 +44,41 @@ def _gather_audio_embeds(audio_embeds: torch.Tensor, token_counts: torch.Tensor)
44
  return audio_embeds[mask]
45
 
46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  class ASRModel(PreTrainedModel, GenerationMixin):
48
  """Audio-to-text model combining an audio encoder, projector, and language model."""
49
 
@@ -431,6 +466,26 @@ class ASRModel(PreTrainedModel, GenerationMixin):
431
  self.config.encoder_conv_layers,
432
  )
433
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
434
  def _encode_audio(
435
  self,
436
  audio_features: torch.Tensor,
@@ -449,6 +504,7 @@ class ASRModel(PreTrainedModel, GenerationMixin):
449
  encoder_out = self.audio_tower(input_features=audio_features)
450
  hidden_states = encoder_out.last_hidden_state
451
 
 
452
  audio_embeds = self.projector(hidden_states)
453
 
454
  token_counts = expected_token_counts.to(device=audio_embeds.device, dtype=torch.long)
 
44
  return audio_embeds[mask]
45
 
46
 
47
+ def _time_mask_encoder_output(
48
+ hidden_states: torch.Tensor,
49
+ num_masks: int,
50
+ max_width_ratio: float,
51
+ ) -> torch.Tensor:
52
+ """SpecAugment-style time masking on encoder output features.
53
+
54
+ Zero-fills ``num_masks`` random contiguous time spans per sample. Each
55
+ span has width sampled uniformly from ``[0, max_width]`` where
56
+ ``max_width = max(1, int(time_len * max_width_ratio))``. Returns the
57
+ input unchanged when ``num_masks <= 0`` or ``max_width_ratio <= 0``.
58
+
59
+ Args:
60
+ hidden_states: ``(batch, time, dim)`` encoder output.
61
+ num_masks: Number of time-mask spans applied per sample.
62
+ max_width_ratio: Maximum mask width as a fraction of ``time``.
63
+ """
64
+ if num_masks <= 0 or max_width_ratio <= 0.0:
65
+ return hidden_states
66
+ batch, time_len, _ = hidden_states.shape
67
+ max_width = max(1, int(time_len * max_width_ratio))
68
+ device = hidden_states.device
69
+
70
+ widths = torch.randint(0, max_width + 1, (batch, num_masks), device=device)
71
+ max_starts = (time_len - widths).clamp(min=1)
72
+ starts = (torch.rand(batch, num_masks, device=device) * max_starts).long()
73
+
74
+ indices = torch.arange(time_len, device=device).view(1, 1, -1)
75
+ starts_e = starts.unsqueeze(-1)
76
+ ends_e = (starts + widths).unsqueeze(-1)
77
+ in_any_mask = ((indices >= starts_e) & (indices < ends_e)).any(dim=1)
78
+ keep = (~in_any_mask).to(dtype=hidden_states.dtype).unsqueeze(-1)
79
+ return hidden_states * keep
80
+
81
+
82
  class ASRModel(PreTrainedModel, GenerationMixin):
83
  """Audio-to-text model combining an audio encoder, projector, and language model."""
84
 
 
466
  self.config.encoder_conv_layers,
467
  )
468
 
469
+ def _apply_encoder_output_time_masking(self, hidden_states: torch.Tensor) -> torch.Tensor:
470
+ """SpecAugment-style time masking on encoder OUTPUT features.
471
+
472
+ For frozen-encoder projector training, mel-side SpecAugment would
473
+ push the encoder OOD (it was never trained with masked mels), so
474
+ we mask the encoder output instead. The projector learns to be
475
+ robust to missing encoder-output time positions. Disabled outside
476
+ training; otherwise delegates to ``_time_mask_encoder_output`` with
477
+ the config knobs.
478
+ """
479
+ if not self.training:
480
+ return hidden_states
481
+ return _time_mask_encoder_output(
482
+ hidden_states,
483
+ num_masks=int(getattr(self.config, "encoder_output_time_mask_num", 0)),
484
+ max_width_ratio=float(
485
+ getattr(self.config, "encoder_output_time_mask_max_width_ratio", 0.0)
486
+ ),
487
+ )
488
+
489
  def _encode_audio(
490
  self,
491
  audio_features: torch.Tensor,
 
504
  encoder_out = self.audio_tower(input_features=audio_features)
505
  hidden_states = encoder_out.last_hidden_state
506
 
507
+ hidden_states = self._apply_encoder_output_time_masking(hidden_states)
508
  audio_embeds = self.projector(hidden_states)
509
 
510
  token_counts = expected_token_counts.to(device=audio_embeds.device, dtype=torch.long)
config.json CHANGED
@@ -234,6 +234,8 @@
234
  ]
235
  ],
236
  "encoder_dim": 1280,
 
 
237
  "eos_token_id": 151645,
238
  "freeze_language_model": false,
239
  "freeze_projector": false,
 
234
  ]
235
  ],
236
  "encoder_dim": 1280,
237
+ "encoder_output_time_mask_max_width_ratio": 0.04,
238
+ "encoder_output_time_mask_num": 5,
239
  "eos_token_id": 151645,
240
  "freeze_language_model": false,
241
  "freeze_projector": false,
model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:76e8c05ce5ce83ffb34c9378af242d2fc9977feadf8908595777ed0caad01d59
3
  size 2433494416
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6acdb5516fe096018c81fec845acf24541064d089ab33133cb57a431aaefc229
3
  size 2433494416