mazesmazes commited on
Commit
02b8ee7
·
verified ·
1 Parent(s): 8a4ea40

Update custom model files, README, and requirements

Browse files
Files changed (1) hide show
  1. asr_modeling.py +17 -0
asr_modeling.py CHANGED
@@ -454,6 +454,23 @@ class ASRModel(PreTrainedModel, GenerationMixin):
454
 
455
  # Replace <audio> token placeholders with audio embeddings using masked_scatter
456
  audio_token_mask = (input_ids == self.audio_token_id).unsqueeze(-1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
457
  inputs_embeds = inputs_embeds.masked_scatter(
458
  audio_token_mask.to(inputs_embeds.device),
459
  audio_embeds.to(inputs_embeds.device, dtype=inputs_embeds.dtype),
 
454
 
455
  # Replace <audio> token placeholders with audio embeddings using masked_scatter
456
  audio_token_mask = (input_ids == self.audio_token_id).unsqueeze(-1)
457
+ num_audio_tokens = audio_token_mask.sum() // audio_token_mask.shape[-1]
458
+ num_audio_embeds = audio_embeds.shape[0]
459
+
460
+ # Handle mismatch between expected tokens and actual embeddings
461
+ if num_audio_embeds < num_audio_tokens:
462
+ # Pad audio embeddings with zeros if we have fewer than expected
463
+ padding = torch.zeros(
464
+ num_audio_tokens - num_audio_embeds,
465
+ audio_embeds.shape[-1],
466
+ device=audio_embeds.device,
467
+ dtype=audio_embeds.dtype,
468
+ )
469
+ audio_embeds = torch.cat([audio_embeds, padding], dim=0)
470
+ elif num_audio_embeds > num_audio_tokens:
471
+ # Truncate if we have more embeddings than tokens
472
+ audio_embeds = audio_embeds[:num_audio_tokens]
473
+
474
  inputs_embeds = inputs_embeds.masked_scatter(
475
  audio_token_mask.to(inputs_embeds.device),
476
  audio_embeds.to(inputs_embeds.device, dtype=inputs_embeds.dtype),