mazesmazes committed
Commit 71dff33 · verified · 1 Parent(s): acfacd1

Training in progress - step 500

Files changed (2)
  1. asr_modeling.py +6 -19
  2. projectors.py +1 -2
asr_modeling.py CHANGED
@@ -269,11 +269,7 @@ class ASRModel(PreTrainedModel, GenerationMixin):
         """Only save trainable projector weights."""
         return {f"projector.{k}": v for k, v in self.projector.state_dict().items()}
 
-    def _apply_specaugment(
-        self,
-        input_features: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
+    def _apply_specaugment(self, input_features: torch.Tensor) -> torch.Tensor:
         if not getattr(self.config, "use_specaugment", False):
             return input_features
 
@@ -294,7 +290,6 @@ class ASRModel(PreTrainedModel, GenerationMixin):
             (batch_size, sequence_length),
             mask_prob=mask_time_prob,
             mask_length=mask_time_length,
-            attention_mask=attention_mask,
             min_masks=2,
         )
         mask_time_indices = torch.tensor(
@@ -321,22 +316,16 @@ class ASRModel(PreTrainedModel, GenerationMixin):
 
         return input_features
 
-    def _encode_audio(
-        self,
-        audio_features: torch.Tensor,
-        audio_attention_mask: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
+    def _encode_audio(self, audio_features: torch.Tensor) -> torch.Tensor:
         """Encode audio and project to LLM embedding space.
 
         Returns flattened audio embeddings of shape (total_audio_tokens, hidden_dim).
         """
         # Apply SpecAugment during training (before encoding)
-        audio_features = self._apply_specaugment(audio_features, audio_attention_mask)
+        audio_features = self._apply_specaugment(audio_features)
 
         with torch.no_grad():
-            encoder_out = self.audio_tower(
-                input_features=audio_features, attention_mask=audio_attention_mask
-            )
+            encoder_out = self.audio_tower(input_features=audio_features)
             hidden_states = encoder_out.last_hidden_state
 
         audio_embeds = self.projector(hidden_states)
@@ -356,7 +345,6 @@ class ASRModel(PreTrainedModel, GenerationMixin):
         labels: Optional[torch.Tensor] = None,
         use_cache: Optional[bool] = None,
         cache_position: Optional[torch.Tensor] = None,
-        audio_attention_mask: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> CausalLMOutputWithPast:
         """Forward pass for training and inference."""
@@ -366,7 +354,7 @@ class ASRModel(PreTrainedModel, GenerationMixin):
 
         if input_features is not None and input_ids is not None:
             # Encode audio -> flattened (total_audio_tokens, hidden_dim)
-            audio_embeds = self._encode_audio(input_features, audio_attention_mask)
+            audio_embeds = self._encode_audio(input_features)
 
             # Replace <audio> token placeholders with audio embeddings using masked_scatter
             audio_token_mask = (input_ids == self.audio_token_id).unsqueeze(-1)
@@ -427,7 +415,6 @@ class ASRModel(PreTrainedModel, GenerationMixin):
         input_ids: Optional[torch.Tensor] = None,
         input_features: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
-        audio_attention_mask: Optional[torch.Tensor] = None,
        system_prompt: Optional[str] = None,
         **generate_kwargs,
     ) -> torch.Tensor:
@@ -444,7 +431,7 @@ class ASRModel(PreTrainedModel, GenerationMixin):
         batch_size = input_features.shape[0]
 
         # Encode audio -> flattened embeddings
-        audio_embeds = self._encode_audio(input_features, audio_attention_mask)
+        audio_embeds = self._encode_audio(input_features)
 
         # If input_ids not provided, build prompt with correct number of audio tokens
         if input_ids is None:
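
For context on the SpecAugment hunk above: with attention_mask dropped, time-mask spans are now sampled over the full (possibly padded) sequence length. A minimal sketch of that masking step, assuming the _compute_mask_indices helper from transformers' wav2vec2 module; mask_prob, mask_length, and the zero fill value are illustrative assumptions, not this repo's settings:

import torch
from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices

# Toy log-mel batch: (batch, n_mels, frames); all sizes are illustrative
input_features = torch.randn(2, 80, 3000)
batch_size, _, sequence_length = input_features.shape

# Sample boolean time-mask spans as in the hunk above; without attention_mask,
# padding frames are eligible for masking too
mask_time_indices = _compute_mask_indices(
    (batch_size, sequence_length),
    mask_prob=0.05,   # assumed value
    mask_length=10,   # assumed value
    min_masks=2,
)
mask = torch.tensor(mask_time_indices, dtype=torch.bool)  # (batch, frames)

# Zero out masked time steps across all mel bins (fill value is an assumption)
input_features = input_features.masked_fill(mask.unsqueeze(1), 0.0)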
 
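And on the masked_scatter line in forward: positions holding the <audio> placeholder token are overwritten, in order, by rows of the flattened audio embeddings. A minimal sketch with toy shapes (the placeholder id and all dimensions are illustrative):

import torch

# One prompt with two <audio> placeholders; id 32000 is a made-up example
audio_token_id = 32000
input_ids = torch.tensor([[1, 32000, 32000, 5, 2]])  # (batch, seq_len)
inputs_embeds = torch.randn(1, 5, 8)                 # (batch, seq_len, hidden_dim)
audio_embeds = torch.randn(2, 8)                     # (total_audio_tokens, hidden_dim)

# Boolean mask over placeholder positions, broadcast across the hidden dim
audio_token_mask = (input_ids == audio_token_id).unsqueeze(-1)  # (batch, seq_len, 1)

# masked_scatter fills the masked positions, in order, from the flattened source
inputs_embeds = inputs_embeds.masked_scatter(
    audio_token_mask, audio_embeds.to(inputs_embeds.dtype)
)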
projectors.py CHANGED
@@ -680,14 +680,13 @@ class QFormerAudioProjector(nn.Module):
         effective_batch = batch_size * nblocks
         hidden_states = hidden_states.view(effective_batch, self.window_size, -1)
 
-        # Expand queries to match batch size (Granite relies on broadcast, but CUDA has issues)
+        # Expand queries to match batch size
         query_embeds = self.query.expand(effective_batch, -1, -1)
 
         # QFormer cross-attention
         query_output = self.qformer(
             query_embeds=query_embeds,
             encoder_hidden_states=hidden_states,
-            encoder_attention_mask=None,
             return_dict=True,
         )
 
 
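On the expand call above: the projector keeps a single learned query set and views it across the windowed batch (expand adds a batch dimension without copying memory), so every window is summarized into a fixed number of query outputs. A minimal stand-in sketch, substituting plain nn.MultiheadAttention for the actual QFormer; all sizes and names here are illustrative:

import torch
import torch.nn as nn

hidden_dim, num_queries, window_size = 64, 4, 16
effective_batch = 6  # batch_size * nblocks after windowing

# One shared set of learned queries, expanded (not copied) across the batch
query = nn.Parameter(torch.randn(1, num_queries, hidden_dim))
query_embeds = query.expand(effective_batch, -1, -1)  # (effective_batch, num_queries, hidden_dim)

# Stand-in for the QFormer's cross-attention over the windowed audio features
cross_attn = nn.MultiheadAttention(hidden_dim, num_heads=4, batch_first=True)
hidden_states = torch.randn(effective_batch, window_size, hidden_dim)
query_output, _ = cross_attn(query_embeds, hidden_states, hidden_states)

# Each window is compressed to num_queries vectors regardless of window_size
assert query_output.shape == (effective_batch, num_queries, hidden_dim)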