Update custom model files, README, and requirements
Browse files
- asr_modeling.py +17 -0
asr_modeling.py
CHANGED
|
@@ -454,6 +454,23 @@ class ASRModel(PreTrainedModel, GenerationMixin):
|
|
| 454 |
|
| 455 |
# Replace <audio> token placeholders with audio embeddings using masked_scatter
|
| 456 |
audio_token_mask = (input_ids == self.audio_token_id).unsqueeze(-1)
|
|
| 457 |
inputs_embeds = inputs_embeds.masked_scatter(
|
| 458 |
audio_token_mask.to(inputs_embeds.device),
|
| 459 |
audio_embeds.to(inputs_embeds.device, dtype=inputs_embeds.dtype),
|
|
|
|
| 454 |
|
| 455 |
# Replace <audio> token placeholders with audio embeddings using masked_scatter
|
| 456 |
audio_token_mask = (input_ids == self.audio_token_id).unsqueeze(-1)
|
| 457 |
+
num_audio_tokens = audio_token_mask.sum() // audio_token_mask.shape[-1]
|
| 458 |
+
num_audio_embeds = audio_embeds.shape[0]
|
| 459 |
+
|
| 460 |
+
# Handle mismatch between expected tokens and actual embeddings
|
| 461 |
+
if num_audio_embeds < num_audio_tokens:
|
| 462 |
+
# Pad audio embeddings with zeros if we have fewer than expected
|
| 463 |
+
padding = torch.zeros(
|
| 464 |
+
num_audio_tokens - num_audio_embeds,
|
| 465 |
+
audio_embeds.shape[-1],
|
| 466 |
+
device=audio_embeds.device,
|
| 467 |
+
dtype=audio_embeds.dtype,
|
| 468 |
+
)
|
| 469 |
+
audio_embeds = torch.cat([audio_embeds, padding], dim=0)
|
| 470 |
+
elif num_audio_embeds > num_audio_tokens:
|
| 471 |
+
# Truncate if we have more embeddings than tokens
|
| 472 |
+
audio_embeds = audio_embeds[:num_audio_tokens]
|
| 473 |
+
|
| 474 |
inputs_embeds = inputs_embeds.masked_scatter(
|
| 475 |
audio_token_mask.to(inputs_embeds.device),
|
| 476 |
audio_embeds.to(inputs_embeds.device, dtype=inputs_embeds.dtype),
|