Update custom model files, README, and requirements
Browse files- asr_modeling.py +39 -32
asr_modeling.py
CHANGED
|
@@ -392,12 +392,15 @@ class ASRModel(PreTrainedModel, GenerationMixin):
|
|
| 392 |
self,
|
| 393 |
audio_features: torch.Tensor,
|
| 394 |
audio_attention_mask: torch.Tensor,
|
|
|
|
| 395 |
) -> torch.Tensor:
|
| 396 |
"""Encode audio and project to LLM embedding space.
|
| 397 |
|
| 398 |
Args:
|
| 399 |
audio_features: Mel spectrogram features (batch, n_mels, mel_len)
|
| 400 |
audio_attention_mask: Mask indicating real vs padded mel frames (batch, mel_len)
|
|
|
|
|
|
|
| 401 |
|
| 402 |
Returns:
|
| 403 |
Flattened audio embeddings of shape (total_audio_tokens, hidden_dim).
|
|
@@ -406,24 +409,40 @@ class ASRModel(PreTrainedModel, GenerationMixin):
|
|
| 406 |
encoder_out = self.audio_tower(input_features=audio_features)
|
| 407 |
hidden_states = encoder_out.last_hidden_state
|
| 408 |
|
| 409 |
-
# Compute per-sample encoder output lengths using conv formulas
|
| 410 |
-
encoder_lengths = self._compute_encoder_output_lengths(audio_attention_mask)
|
| 411 |
-
|
| 412 |
# Project to LLM space
|
| 413 |
audio_embeds = self.projector(hidden_states)
|
| 414 |
|
| 415 |
-
#
|
| 416 |
-
|
| 417 |
-
|
| 418 |
-
|
| 419 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 420 |
|
| 421 |
-
#
|
| 422 |
-
|
| 423 |
-
|
| 424 |
-
|
| 425 |
-
|
| 426 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 427 |
|
| 428 |
def forward(
|
| 429 |
self,
|
|
@@ -449,28 +468,16 @@ class ASRModel(PreTrainedModel, GenerationMixin):
|
|
| 449 |
if self.training and self.spec_augment is not None:
|
| 450 |
input_features = self.spec_augment(input_features)
|
| 451 |
|
|
|
|
|
|
|
|
|
|
| 452 |
# Encode audio -> flattened (total_audio_tokens, hidden_dim)
|
| 453 |
-
audio_embeds = self._encode_audio(
|
|
|
|
|
|
|
| 454 |
|
| 455 |
# Replace <audio> token placeholders with audio embeddings using masked_scatter
|
| 456 |
audio_token_mask = (input_ids == self.audio_token_id).unsqueeze(-1)
|
| 457 |
-
num_audio_tokens = audio_token_mask.sum() // audio_token_mask.shape[-1]
|
| 458 |
-
num_audio_embeds = audio_embeds.shape[0]
|
| 459 |
-
|
| 460 |
-
# Handle mismatch between expected tokens and actual embeddings
|
| 461 |
-
if num_audio_embeds < num_audio_tokens:
|
| 462 |
-
# Pad audio embeddings with zeros if we have fewer than expected
|
| 463 |
-
padding = torch.zeros(
|
| 464 |
-
num_audio_tokens - num_audio_embeds,
|
| 465 |
-
audio_embeds.shape[-1],
|
| 466 |
-
device=audio_embeds.device,
|
| 467 |
-
dtype=audio_embeds.dtype,
|
| 468 |
-
)
|
| 469 |
-
audio_embeds = torch.cat([audio_embeds, padding], dim=0)
|
| 470 |
-
elif num_audio_embeds > num_audio_tokens:
|
| 471 |
-
# Truncate if we have more embeddings than tokens
|
| 472 |
-
audio_embeds = audio_embeds[:num_audio_tokens]
|
| 473 |
-
|
| 474 |
inputs_embeds = inputs_embeds.masked_scatter(
|
| 475 |
audio_token_mask.to(inputs_embeds.device),
|
| 476 |
audio_embeds.to(inputs_embeds.device, dtype=inputs_embeds.dtype),
|
|
|
|
| 392 |
self,
|
| 393 |
audio_features: torch.Tensor,
|
| 394 |
audio_attention_mask: torch.Tensor,
|
| 395 |
+
expected_token_counts: torch.Tensor | None = None,
|
| 396 |
) -> torch.Tensor:
|
| 397 |
"""Encode audio and project to LLM embedding space.
|
| 398 |
|
| 399 |
Args:
|
| 400 |
audio_features: Mel spectrogram features (batch, n_mels, mel_len)
|
| 401 |
audio_attention_mask: Mask indicating real vs padded mel frames (batch, mel_len)
|
| 402 |
+
expected_token_counts: Expected number of audio tokens per sample from input_ids.
|
| 403 |
+
If provided, output will match these counts exactly (padding/truncating as needed).
|
| 404 |
|
| 405 |
Returns:
|
| 406 |
Flattened audio embeddings of shape (total_audio_tokens, hidden_dim).
|
|
|
|
| 409 |
encoder_out = self.audio_tower(input_features=audio_features)
|
| 410 |
hidden_states = encoder_out.last_hidden_state
|
| 411 |
|
|
|
|
|
|
|
|
|
|
| 412 |
# Project to LLM space
|
| 413 |
audio_embeds = self.projector(hidden_states)
|
| 414 |
|
| 415 |
+
# Use expected token counts if provided (from input_ids), otherwise compute from audio
|
| 416 |
+
if expected_token_counts is not None:
|
| 417 |
+
token_counts = expected_token_counts
|
| 418 |
+
else:
|
| 419 |
+
# Compute per-sample encoder output lengths using conv formulas
|
| 420 |
+
encoder_lengths = self._compute_encoder_output_lengths(audio_attention_mask)
|
| 421 |
+
token_counts = torch.tensor(
|
| 422 |
+
[self.projector.get_output_length(int(length.item())) for length in encoder_lengths],
|
| 423 |
+
device=audio_embeds.device,
|
| 424 |
+
)
|
| 425 |
|
| 426 |
+
# Extract embeddings matching expected token counts per sample
|
| 427 |
+
batch_size = audio_embeds.shape[0]
|
| 428 |
+
hidden_dim = audio_embeds.shape[2]
|
| 429 |
+
|
| 430 |
+
result_embeds = []
|
| 431 |
+
for i in range(batch_size):
|
| 432 |
+
count = int(token_counts[i].item())
|
| 433 |
+
sample_embeds = audio_embeds[i, :count, :] # Take first 'count' embeddings
|
| 434 |
+
# Pad with zeros if we don't have enough embeddings
|
| 435 |
+
if sample_embeds.shape[0] < count:
|
| 436 |
+
padding = torch.zeros(
|
| 437 |
+
count - sample_embeds.shape[0],
|
| 438 |
+
hidden_dim,
|
| 439 |
+
device=audio_embeds.device,
|
| 440 |
+
dtype=audio_embeds.dtype,
|
| 441 |
+
)
|
| 442 |
+
sample_embeds = torch.cat([sample_embeds, padding], dim=0)
|
| 443 |
+
result_embeds.append(sample_embeds)
|
| 444 |
+
|
| 445 |
+
return torch.cat(result_embeds, dim=0)
|
| 446 |
|
| 447 |
def forward(
|
| 448 |
self,
|
|
|
|
| 468 |
if self.training and self.spec_augment is not None:
|
| 469 |
input_features = self.spec_augment(input_features)
|
| 470 |
|
| 471 |
+
# Count expected audio tokens from input_ids (ground truth from collator)
|
| 472 |
+
audio_token_counts = (input_ids == self.audio_token_id).sum(dim=-1)
|
| 473 |
+
|
| 474 |
# Encode audio -> flattened (total_audio_tokens, hidden_dim)
|
| 475 |
+
audio_embeds = self._encode_audio(
|
| 476 |
+
input_features, audio_attention_mask, audio_token_counts
|
| 477 |
+
)
|
| 478 |
|
| 479 |
# Replace <audio> token placeholders with audio embeddings using masked_scatter
|
| 480 |
audio_token_mask = (input_ids == self.audio_token_id).unsqueeze(-1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 481 |
inputs_embeds = inputs_embeds.masked_scatter(
|
| 482 |
audio_token_mask.to(inputs_embeds.device),
|
| 483 |
audio_embeds.to(inputs_embeds.device, dtype=inputs_embeds.dtype),
|