mazesmazes committed on
Commit
634ce62
·
verified ·
1 Parent(s): b729799

Training in progress - step 5000

Browse files
Files changed (2) hide show
  1. asr_modeling.py +39 -15
  2. model.safetensors +1 -1
asr_modeling.py CHANGED
@@ -392,12 +392,15 @@ class ASRModel(PreTrainedModel, GenerationMixin):
392
  self,
393
  audio_features: torch.Tensor,
394
  audio_attention_mask: torch.Tensor,
 
395
  ) -> torch.Tensor:
396
  """Encode audio and project to LLM embedding space.
397
 
398
  Args:
399
  audio_features: Mel spectrogram features (batch, n_mels, mel_len)
400
  audio_attention_mask: Mask indicating real vs padded mel frames (batch, mel_len)
 
 
401
 
402
  Returns:
403
  Flattened audio embeddings of shape (total_audio_tokens, hidden_dim).
@@ -406,24 +409,40 @@ class ASRModel(PreTrainedModel, GenerationMixin):
406
  encoder_out = self.audio_tower(input_features=audio_features)
407
  hidden_states = encoder_out.last_hidden_state
408
 
409
- # Compute per-sample encoder output lengths using conv formulas
410
- encoder_lengths = self._compute_encoder_output_lengths(audio_attention_mask)
411
-
412
  # Project to LLM space
413
  audio_embeds = self.projector(hidden_states)
414
 
415
- # Compute per-sample projector output lengths
416
- projector_lengths = torch.tensor(
417
- [self.projector.get_output_length(int(length.item())) for length in encoder_lengths],
418
- device=audio_embeds.device,
419
- )
 
 
 
 
 
420
 
421
- # Create valid mask for variable-length samples and extract only real embeddings
422
- max_len = audio_embeds.shape[1]
423
- valid_mask = (
424
- torch.arange(max_len, device=audio_embeds.device)[None, :] < projector_lengths[:, None]
425
- )
426
- return audio_embeds[valid_mask]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
427
 
428
  def forward(
429
  self,
@@ -449,8 +468,13 @@ class ASRModel(PreTrainedModel, GenerationMixin):
449
  if self.training and self.spec_augment is not None:
450
  input_features = self.spec_augment(input_features)
451
 
 
 
 
452
  # Encode audio -> flattened (total_audio_tokens, hidden_dim)
453
- audio_embeds = self._encode_audio(input_features, audio_attention_mask)
 
 
454
 
455
  # Replace <audio> token placeholders with audio embeddings using masked_scatter
456
  audio_token_mask = (input_ids == self.audio_token_id).unsqueeze(-1)
 
392
  self,
393
  audio_features: torch.Tensor,
394
  audio_attention_mask: torch.Tensor,
395
+ expected_token_counts: torch.Tensor | None = None,
396
  ) -> torch.Tensor:
397
  """Encode audio and project to LLM embedding space.
398
 
399
  Args:
400
  audio_features: Mel spectrogram features (batch, n_mels, mel_len)
401
  audio_attention_mask: Mask indicating real vs padded mel frames (batch, mel_len)
402
+ expected_token_counts: Expected number of audio tokens per sample from input_ids.
403
+ If provided, output will match these counts exactly (padding/truncating as needed).
404
 
405
  Returns:
406
  Flattened audio embeddings of shape (total_audio_tokens, hidden_dim).
 
409
  encoder_out = self.audio_tower(input_features=audio_features)
410
  hidden_states = encoder_out.last_hidden_state
411
 
 
 
 
412
  # Project to LLM space
413
  audio_embeds = self.projector(hidden_states)
414
 
415
+ # Use expected token counts if provided (from input_ids), otherwise compute from audio
416
+ if expected_token_counts is not None:
417
+ token_counts = expected_token_counts
418
+ else:
419
+ # Compute per-sample encoder output lengths using conv formulas
420
+ encoder_lengths = self._compute_encoder_output_lengths(audio_attention_mask)
421
+ token_counts = torch.tensor(
422
+ [self.projector.get_output_length(int(length.item())) for length in encoder_lengths],
423
+ device=audio_embeds.device,
424
+ )
425
 
426
+ # Extract embeddings matching expected token counts per sample
427
+ batch_size = audio_embeds.shape[0]
428
+ hidden_dim = audio_embeds.shape[2]
429
+
430
+ result_embeds = []
431
+ for i in range(batch_size):
432
+ count = int(token_counts[i].item())
433
+ sample_embeds = audio_embeds[i, :count, :] # Take first 'count' embeddings
434
+ # Pad with zeros if we don't have enough embeddings
435
+ if sample_embeds.shape[0] < count:
436
+ padding = torch.zeros(
437
+ count - sample_embeds.shape[0],
438
+ hidden_dim,
439
+ device=audio_embeds.device,
440
+ dtype=audio_embeds.dtype,
441
+ )
442
+ sample_embeds = torch.cat([sample_embeds, padding], dim=0)
443
+ result_embeds.append(sample_embeds)
444
+
445
+ return torch.cat(result_embeds, dim=0)
446
 
447
  def forward(
448
  self,
 
468
  if self.training and self.spec_augment is not None:
469
  input_features = self.spec_augment(input_features)
470
 
471
+ # Count expected audio tokens from input_ids (ground truth from collator)
472
+ audio_token_counts = (input_ids == self.audio_token_id).sum(dim=-1)
473
+
474
  # Encode audio -> flattened (total_audio_tokens, hidden_dim)
475
+ audio_embeds = self._encode_audio(
476
+ input_features, audio_attention_mask, audio_token_counts
477
+ )
478
 
479
  # Replace <audio> token placeholders with audio embeddings using masked_scatter
480
  audio_token_mask = (input_ids == self.audio_token_id).unsqueeze(-1)
model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:e1c29578f6e4473b5f6a25ba03515832cfc1c5698f516d02f7758722d09b7065
3
  size 58732960
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bdd57ee656ad1fbf90b00aacb15682654daadb76f7a34fc574f8ca94e35ef178
3
  size 58732960