mazesmazes committed on
Commit
a4007a7
·
verified ·
1 Parent(s): 02b8ee7

Update custom model files, README, and requirements

Browse files
Files changed (1) hide show
  1. asr_modeling.py +39 -32
asr_modeling.py CHANGED
@@ -392,12 +392,15 @@ class ASRModel(PreTrainedModel, GenerationMixin):
392
  self,
393
  audio_features: torch.Tensor,
394
  audio_attention_mask: torch.Tensor,
 
395
  ) -> torch.Tensor:
396
  """Encode audio and project to LLM embedding space.
397
 
398
  Args:
399
  audio_features: Mel spectrogram features (batch, n_mels, mel_len)
400
  audio_attention_mask: Mask indicating real vs padded mel frames (batch, mel_len)
 
 
401
 
402
  Returns:
403
  Flattened audio embeddings of shape (total_audio_tokens, hidden_dim).
@@ -406,24 +409,40 @@ class ASRModel(PreTrainedModel, GenerationMixin):
406
  encoder_out = self.audio_tower(input_features=audio_features)
407
  hidden_states = encoder_out.last_hidden_state
408
 
409
- # Compute per-sample encoder output lengths using conv formulas
410
- encoder_lengths = self._compute_encoder_output_lengths(audio_attention_mask)
411
-
412
  # Project to LLM space
413
  audio_embeds = self.projector(hidden_states)
414
 
415
- # Compute per-sample projector output lengths
416
- projector_lengths = torch.tensor(
417
- [self.projector.get_output_length(int(length.item())) for length in encoder_lengths],
418
- device=audio_embeds.device,
419
- )
 
 
 
 
 
420
 
421
- # Create valid mask for variable-length samples and extract only real embeddings
422
- max_len = audio_embeds.shape[1]
423
- valid_mask = (
424
- torch.arange(max_len, device=audio_embeds.device)[None, :] < projector_lengths[:, None]
425
- )
426
- return audio_embeds[valid_mask]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
427
 
428
  def forward(
429
  self,
@@ -449,28 +468,16 @@ class ASRModel(PreTrainedModel, GenerationMixin):
449
  if self.training and self.spec_augment is not None:
450
  input_features = self.spec_augment(input_features)
451
 
 
 
 
452
  # Encode audio -> flattened (total_audio_tokens, hidden_dim)
453
- audio_embeds = self._encode_audio(input_features, audio_attention_mask)
 
 
454
 
455
  # Replace <audio> token placeholders with audio embeddings using masked_scatter
456
  audio_token_mask = (input_ids == self.audio_token_id).unsqueeze(-1)
457
- num_audio_tokens = audio_token_mask.sum() // audio_token_mask.shape[-1]
458
- num_audio_embeds = audio_embeds.shape[0]
459
-
460
- # Handle mismatch between expected tokens and actual embeddings
461
- if num_audio_embeds < num_audio_tokens:
462
- # Pad audio embeddings with zeros if we have fewer than expected
463
- padding = torch.zeros(
464
- num_audio_tokens - num_audio_embeds,
465
- audio_embeds.shape[-1],
466
- device=audio_embeds.device,
467
- dtype=audio_embeds.dtype,
468
- )
469
- audio_embeds = torch.cat([audio_embeds, padding], dim=0)
470
- elif num_audio_embeds > num_audio_tokens:
471
- # Truncate if we have more embeddings than tokens
472
- audio_embeds = audio_embeds[:num_audio_tokens]
473
-
474
  inputs_embeds = inputs_embeds.masked_scatter(
475
  audio_token_mask.to(inputs_embeds.device),
476
  audio_embeds.to(inputs_embeds.device, dtype=inputs_embeds.dtype),
 
392
  self,
393
  audio_features: torch.Tensor,
394
  audio_attention_mask: torch.Tensor,
395
+ expected_token_counts: torch.Tensor | None = None,
396
  ) -> torch.Tensor:
397
  """Encode audio and project to LLM embedding space.
398
 
399
  Args:
400
  audio_features: Mel spectrogram features (batch, n_mels, mel_len)
401
  audio_attention_mask: Mask indicating real vs padded mel frames (batch, mel_len)
402
+ expected_token_counts: Expected number of audio tokens per sample from input_ids.
403
+ If provided, output will match these counts exactly (padding/truncating as needed).
404
 
405
  Returns:
406
  Flattened audio embeddings of shape (total_audio_tokens, hidden_dim).
 
409
  encoder_out = self.audio_tower(input_features=audio_features)
410
  hidden_states = encoder_out.last_hidden_state
411
 
 
 
 
412
  # Project to LLM space
413
  audio_embeds = self.projector(hidden_states)
414
 
415
+ # Use expected token counts if provided (from input_ids), otherwise compute from audio
416
+ if expected_token_counts is not None:
417
+ token_counts = expected_token_counts
418
+ else:
419
+ # Compute per-sample encoder output lengths using conv formulas
420
+ encoder_lengths = self._compute_encoder_output_lengths(audio_attention_mask)
421
+ token_counts = torch.tensor(
422
+ [self.projector.get_output_length(int(length.item())) for length in encoder_lengths],
423
+ device=audio_embeds.device,
424
+ )
425
 
426
+ # Extract embeddings matching expected token counts per sample
427
+ batch_size = audio_embeds.shape[0]
428
+ hidden_dim = audio_embeds.shape[2]
429
+
430
+ result_embeds = []
431
+ for i in range(batch_size):
432
+ count = int(token_counts[i].item())
433
+ sample_embeds = audio_embeds[i, :count, :] # Take first 'count' embeddings
434
+ # Pad with zeros if we don't have enough embeddings
435
+ if sample_embeds.shape[0] < count:
436
+ padding = torch.zeros(
437
+ count - sample_embeds.shape[0],
438
+ hidden_dim,
439
+ device=audio_embeds.device,
440
+ dtype=audio_embeds.dtype,
441
+ )
442
+ sample_embeds = torch.cat([sample_embeds, padding], dim=0)
443
+ result_embeds.append(sample_embeds)
444
+
445
+ return torch.cat(result_embeds, dim=0)
446
 
447
  def forward(
448
  self,
 
468
  if self.training and self.spec_augment is not None:
469
  input_features = self.spec_augment(input_features)
470
 
471
+ # Count expected audio tokens from input_ids (ground truth from collator)
472
+ audio_token_counts = (input_ids == self.audio_token_id).sum(dim=-1)
473
+
474
  # Encode audio -> flattened (total_audio_tokens, hidden_dim)
475
+ audio_embeds = self._encode_audio(
476
+ input_features, audio_attention_mask, audio_token_counts
477
+ )
478
 
479
  # Replace <audio> token placeholders with audio embeddings using masked_scatter
480
  audio_token_mask = (input_ids == self.audio_token_id).unsqueeze(-1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
481
  inputs_embeds = inputs_embeds.masked_scatter(
482
  audio_token_mask.to(inputs_embeds.device),
483
  audio_embeds.to(inputs_embeds.device, dtype=inputs_embeds.dtype),