Training in progress - step 500

- asr_modeling.py +6 -19
- projectors.py +1 -2

asr_modeling.py
CHANGED
@@ -269,11 +269,7 @@ class ASRModel(PreTrainedModel, GenerationMixin):
         """Only save trainable projector weights."""
         return {f"projector.{k}": v for k, v in self.projector.state_dict().items()}
 
-    def _apply_specaugment(
-        self,
-        input_features: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
+    def _apply_specaugment(self, input_features: torch.Tensor) -> torch.Tensor:
         if not getattr(self.config, "use_specaugment", False):
             return input_features
 
@@ -294,7 +290,6 @@ class ASRModel(PreTrainedModel, GenerationMixin):
             (batch_size, sequence_length),
             mask_prob=mask_time_prob,
             mask_length=mask_time_length,
-            attention_mask=attention_mask,
             min_masks=2,
         )
         mask_time_indices = torch.tensor(
@@ -321,22 +316,16 @@ class ASRModel(PreTrainedModel, GenerationMixin):
 
         return input_features
 
-    def _encode_audio(
-        self,
-        audio_features: torch.Tensor,
-        audio_attention_mask: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
+    def _encode_audio(self, audio_features: torch.Tensor) -> torch.Tensor:
         """Encode audio and project to LLM embedding space.
 
         Returns flattened audio embeddings of shape (total_audio_tokens, hidden_dim).
         """
         # Apply SpecAugment during training (before encoding)
-        audio_features = self._apply_specaugment(audio_features, audio_attention_mask)
+        audio_features = self._apply_specaugment(audio_features)
 
         with torch.no_grad():
-            encoder_out = self.audio_tower(
-                input_features=audio_features, attention_mask=audio_attention_mask
-            )
+            encoder_out = self.audio_tower(input_features=audio_features)
         hidden_states = encoder_out.last_hidden_state
 
         audio_embeds = self.projector(hidden_states)
@@ -356,7 +345,6 @@ class ASRModel(PreTrainedModel, GenerationMixin):
         labels: Optional[torch.Tensor] = None,
         use_cache: Optional[bool] = None,
         cache_position: Optional[torch.Tensor] = None,
-        audio_attention_mask: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> CausalLMOutputWithPast:
         """Forward pass for training and inference."""
@@ -366,7 +354,7 @@ class ASRModel(PreTrainedModel, GenerationMixin):
 
         if input_features is not None and input_ids is not None:
             # Encode audio -> flattened (total_audio_tokens, hidden_dim)
-            audio_embeds = self._encode_audio(input_features, audio_attention_mask)
+            audio_embeds = self._encode_audio(input_features)
 
             # Replace <audio> token placeholders with audio embeddings using masked_scatter
            audio_token_mask = (input_ids == self.audio_token_id).unsqueeze(-1)
@@ -427,7 +415,6 @@ class ASRModel(PreTrainedModel, GenerationMixin):
         input_ids: Optional[torch.Tensor] = None,
         input_features: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
-        audio_attention_mask: Optional[torch.Tensor] = None,
         system_prompt: Optional[str] = None,
         **generate_kwargs,
     ) -> torch.Tensor:
@@ -444,7 +431,7 @@ class ASRModel(PreTrainedModel, GenerationMixin):
         batch_size = input_features.shape[0]
 
         # Encode audio -> flattened embeddings
-        audio_embeds = self._encode_audio(input_features, audio_attention_mask)
+        audio_embeds = self._encode_audio(input_features)
 
         # If input_ids not provided, build prompt with correct number of audio tokens
         if input_ids is None:
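The forward hunk splices the flattened audio embeddings into the text embedding sequence with masked_scatter. A self-contained toy illustration of that mechanic (the sizes and audio_token_id value are made up):

import torch

hidden_dim, audio_token_id = 4, 32000  # made-up values for illustration

input_ids = torch.tensor([[5, audio_token_id, audio_token_id, 9]])
inputs_embeds = torch.zeros(1, 4, hidden_dim)  # stand-in for embed_tokens(input_ids)
audio_embeds = torch.ones(2, hidden_dim)       # flattened (total_audio_tokens, hidden_dim)

# The mask broadcasts over the hidden dimension; masked positions are filled
# from audio_embeds in order, one embedding row per <audio> placeholder.
audio_token_mask = (input_ids == audio_token_id).unsqueeze(-1)
inputs_embeds = inputs_embeds.masked_scatter(audio_token_mask, audio_embeds)

assert inputs_embeds[0, 1].eq(1).all() and inputs_embeds[0, 3].eq(0).all()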
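The state_dict override in the first hunk ("Only save trainable projector weights") is what keeps this step-500 checkpoint small: the frozen audio tower and LLM are never serialized. A sketch of the resulting save/load round trip, using a stand-in module rather than the real ASRModel:

import torch
import torch.nn as nn

class TinyASRModel(nn.Module):
    # Stand-in: frozen tower plus trainable projector (assumed structure).
    def __init__(self):
        super().__init__()
        self.audio_tower = nn.Linear(80, 16)  # placeholder for the frozen encoder
        self.projector = nn.Linear(16, 32)    # the only trainable part

    def state_dict(self, *args, **kwargs):
        # Mirror the diff: only save trainable projector weights.
        return {f"projector.{k}": v for k, v in self.projector.state_dict().items()}

model = TinyASRModel()
torch.save(model.state_dict(), "projector_only.pt")

# Restore with strict=False so the missing (frozen) tower keys are tolerated.
restored = TinyASRModel()
missing, unexpected = restored.load_state_dict(torch.load("projector_only.pt"), strict=False)
assert all(k.startswith("audio_tower.") for k in missing) and not unexpected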

projectors.py
CHANGED
@@ -680,14 +680,13 @@ class QFormerAudioProjector(nn.Module):
         effective_batch = batch_size * nblocks
         hidden_states = hidden_states.view(effective_batch, self.window_size, -1)
 
-        # Expand queries to match batch size
+        # Expand queries to match batch size
         query_embeds = self.query.expand(effective_batch, -1, -1)
 
         # QFormer cross-attention
         query_output = self.qformer(
             query_embeds=query_embeds,
             encoder_hidden_states=hidden_states,
-            encoder_attention_mask=None,
             return_dict=True,
         )
 
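For context on the projectors.py hunk: the learned queries are broadcast over every audio window before the Q-Former cross-attention. A sketch of that windowed reshape with made-up sizes (window_size, num_queries, and the dims are illustrative, not taken from the repo):

import torch
import torch.nn as nn

batch_size, seq_len, enc_dim = 2, 100, 768   # illustrative sizes
window_size, num_queries = 25, 4

hidden_states = torch.randn(batch_size, seq_len, enc_dim)
query = nn.Parameter(torch.zeros(1, num_queries, enc_dim))  # learned queries

# Split the sequence into fixed windows; each window is attended independently,
# so the windows fold into the batch dimension.
nblocks = seq_len // window_size
effective_batch = batch_size * nblocks
hidden_states = hidden_states.view(effective_batch, window_size, -1)

# expand() broadcasts the single query set across all windows without copying
# memory (unlike repeat()); each window gets the same num_queries slots.
query_embeds = query.expand(effective_batch, -1, -1)
assert query_embeds.shape == (effective_batch, num_queries, enc_dim)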