mazesmazes committed
Commit 94b6d6d · verified · 1 parent: 968a208

Training in progress - step 500

Files changed (4)
  1. asr_config.py +0 -2
  2. asr_modeling.py +0 -53
  3. model.safetensors +2 -2
  4. projectors.py +44 -66
asr_config.py CHANGED
@@ -30,7 +30,6 @@ class ASRConfig(transformers.PretrainedConfig):
         num_experts: int = 4,  # Number of experts in MoE projectors
         num_experts_per_tok: int = 2,  # Top-k experts per token
         router_aux_loss_coef: float = 0.01,  # Auxiliary loss coefficient for load balancing
-        use_specaugment: bool = True,  # Apply SpecAugment during training
         # QFormer-specific configuration (Granite defaults)
         qformer_window_size: int = 15,  # Window size for QFormer processing
         qformer_hidden_size: Optional[int] = None,  # QFormer hidden size (defaults to encoder_dim)
@@ -79,7 +78,6 @@ class ASRConfig(transformers.PretrainedConfig):
         self.num_experts = num_experts
         self.num_experts_per_tok = num_experts_per_tok
         self.router_aux_loss_coef = router_aux_loss_coef
-        self.use_specaugment = use_specaugment
         # QFormer-specific configuration
         self.qformer_window_size = qformer_window_size
         self.qformer_hidden_size = qformer_hidden_size
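
For context, a minimal construction sketch of the slimmed config (field names and defaults are taken from the hunk above; the top-level import path is an assumption and the values are illustrative):

# Sketch only: assumes asr_config.py is importable from the repo root.
from asr_config import ASRConfig

config = ASRConfig(
    num_experts=4,              # MoE projector experts
    num_experts_per_tok=2,      # top-k experts routed per token
    router_aux_loss_coef=0.01,  # load-balancing auxiliary loss weight
    qformer_window_size=15,     # Granite default window size
)
assert not hasattr(config, "use_specaugment")  # field removed in this commit
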
asr_modeling.py CHANGED
@@ -13,9 +13,6 @@ from transformers import (
 )
 from transformers.generation import GenerationMixin
 from transformers.modeling_outputs import CausalLMOutputWithPast
-from transformers.models.whisper.modeling_whisper import (
-    _compute_mask_indices,
-)

 try:
     from .asr_config import ASRConfig
@@ -269,53 +266,6 @@ class ASRModel(PreTrainedModel, GenerationMixin):
         """Only save trainable projector weights."""
         return {f"projector.{k}": v for k, v in self.projector.state_dict().items()}

-    def _apply_specaugment(self, input_features: torch.Tensor) -> torch.Tensor:
-        if not getattr(self.config, "use_specaugment", False):
-            return input_features
-
-        if not self.training:
-            return input_features
-
-        # Input shape: (batch_size, num_mel_bins, sequence_length) for Whisper
-        batch_size, hidden_size, sequence_length = input_features.size()
-
-        mask_time_prob = getattr(self.config, "mask_time_prob", 0.05)
-        mask_time_length = getattr(self.config, "mask_time_length", 10)
-        mask_feature_prob = getattr(self.config, "mask_feature_prob", 0.0)
-        mask_feature_length = getattr(self.config, "mask_feature_length", 10)
-
-        # Time masking
-        if mask_time_prob > 0:
-            mask_time_np = _compute_mask_indices(
-                (batch_size, sequence_length),
-                mask_prob=mask_time_prob,
-                mask_length=mask_time_length,
-                min_masks=2,
-            )
-            mask_time_indices = torch.tensor(
-                mask_time_np, device=input_features.device, dtype=torch.bool
-            )
-            # Expand to cover all features: (batch, seq) -> (batch, features, seq)
-            mask_time_expanded = mask_time_indices[:, None].expand(-1, hidden_size, -1)
-            input_features = input_features.masked_fill(mask_time_expanded, 0.0)
-
-        # Feature masking
-        if mask_feature_prob > 0:
-            mask_feature_np = _compute_mask_indices(
-                (batch_size, hidden_size),
-                mask_prob=mask_feature_prob,
-                mask_length=mask_feature_length,
-                min_masks=2,
-            )
-            mask_feature_indices = torch.tensor(
-                mask_feature_np, device=input_features.device, dtype=torch.bool
-            )
-            # Expand: (batch, features) -> (batch, features, seq)
-            mask_feature_expanded = mask_feature_indices[:, :, None].expand(-1, -1, sequence_length)
-            input_features = input_features.masked_fill(mask_feature_expanded, 0.0)
-
-        return input_features
-
     def _encode_audio(
         self,
         audio_features: torch.Tensor,
@@ -330,9 +280,6 @@ class ASRModel(PreTrainedModel, GenerationMixin):
         Returns:
             Flattened audio embeddings of shape (total_audio_tokens, hidden_dim).
         """
-        # Apply SpecAugment during training (before encoding)
-        audio_features = self._apply_specaugment(audio_features)
-
         with torch.no_grad():
             encoder_out = self.audio_tower(input_features=audio_features)
             hidden_states = encoder_out.last_hidden_state
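
For reference, the deleted _apply_specaugment reduced to the masking pattern below (a standalone sketch: _compute_mask_indices is the real transformers utility the old import pulled in, while the batch/mel/frame sizes are illustrative of Whisper's (batch, num_mel_bins, sequence_length) features):

import torch
from transformers.models.whisper.modeling_whisper import _compute_mask_indices

feats = torch.randn(2, 80, 3000)  # Whisper log-mel features: (batch, mel_bins, frames)
mask_np = _compute_mask_indices((2, 3000), mask_prob=0.05, mask_length=10, min_masks=2)
mask = torch.from_numpy(mask_np)  # boolean (batch, frames) time mask
feats = feats.masked_fill(mask[:, None, :].expand(-1, 80, -1), 0.0)  # zero masked frames
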
model.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:141150eea397b1108cb83d033075f10dfdce3d3d652d48d647d34cf7b804bb50
-size 721963128
+oid sha256:82969523f91eb82b594c3669afa1c1cc8d1c49d5e3414997e659c8214b8f5942
+size 124082792
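
Note: the checkpoint shrinks from 721,963,128 bytes (~722 MB) to 124,082,792 bytes (~124 MB), consistent with get_projector_state_dict above persisting only the trainable projector weights; assuming fp32 storage, that is roughly 124,082,792 / 4 ≈ 31M parameters.
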
projectors.py CHANGED
@@ -222,28 +222,23 @@ class MOSAProjector(nn.Module):


 class SwiGLU(nn.Module):
-    def __init__(self, in_features, hidden_features, out_features, bias=False, dropout=0.0):
+    """SwiGLU activation block (Llama-style: SiLU(Gate) * Value -> Output)."""
+
+    def __init__(self, in_features, hidden_features, out_features):
         super().__init__()
-        self.w1 = nn.Linear(in_features, hidden_features, bias=bias)
-        self.w2 = nn.Linear(in_features, hidden_features, bias=bias)
-        self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
+        self.w1 = nn.Linear(in_features, hidden_features, bias=False)  # Gate
+        self.w2 = nn.Linear(in_features, hidden_features, bias=False)  # Value
+        self.w3 = nn.Linear(hidden_features, out_features, bias=False)  # Output
         self.act = nn.SiLU()
-        self.dropout = nn.Dropout(dropout)

     def forward(self, x):
-        x_gate = self.act(self.w1(x))
-        x_val = self.w2(x)
-        x = x_gate * x_val
-        x = self.dropout(x)
-        return self.w3(x)
+        return self.w3(self.act(self.w1(x)) * self.w2(x))


 class SwiGLUAudioProjector(nn.Module):
     """
-    SwiGLU projector with:
-    1. C-Abstractor style dual-path context (Conv + AvgPool).
-    2. Llama 3 style hidden dimension calculation.
-    3. RMSNorm for training stability.
+    Optimized for Frozen LLM + 2500h Data.
+    Target: 12.5 Hz Output (Stride 4) with 8/3 SwiGLU Expansion.
     """

     def __init__(self, config):
@@ -252,69 +247,58 @@ class SwiGLUAudioProjector(nn.Module):
         encoder_dim = config.encoder_dim
         llm_dim = config.llm_dim

-        # 1. C-Abstractor Style Dual-Path Context
-        # Path A: Depthwise Conv (Phonetic features)
-        self.local_context = nn.Conv1d(
-            encoder_dim, encoder_dim, kernel_size=3, padding=1, groups=encoder_dim, bias=False
+        # Conv Expansion (Compensating for Time Compression)
+        # We compress time by 4x, so we expand width by 2x to preserve info density.
+        hidden_dim = int(encoder_dim * 2)
+
+        # SwiGLU Internal Expansion (The 8/3 Ratio)
+        # To match standard FFN capacity: 4 * (2/3) = 8/3
+        swiglu_inner = int(hidden_dim * 8 / 3)
+
+        self.downsample = nn.Conv1d(
+            in_channels=encoder_dim,
+            out_channels=hidden_dim,
+            kernel_size=self.k,
+            stride=self.k,
+            padding=0,
         )
-        # Path B: Mean Pooling (Energy/Prosody features)
-        # We use a kernel of 3 to match the Conv1d's receptive field
-        self.energy_pool = nn.AvgPool1d(kernel_size=3, stride=1, padding=1)

-        # 2. Llama 3 Style Dimension Calculation
-        d_model = encoder_dim * self.k
-        hidden_dim = int(2 * (d_model * 4) / 3)
-        multiple_of = 256
-        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
+        self.norm = LlamaRMSNorm(hidden_dim, eps=1e-8)

-        # 3. Normalization and SwiGLU
-        self.pre_norm = LlamaRMSNorm(d_model, eps=1e-8)
-        self.proj1 = SwiGLU(d_model, hidden_dim, hidden_dim)
-        self.proj2 = nn.Linear(hidden_dim, llm_dim, bias=False)
+        self.proj = SwiGLU(hidden_dim, swiglu_inner, llm_dim)

         self.apply(self._init_weights)

     def _init_weights(self, m):
-        if isinstance(m, nn.Linear):
+        if isinstance(m, (nn.Linear, nn.Conv1d)):
             nn.init.trunc_normal_(m.weight, std=0.02)
             if m.bias is not None:
-                nn.init.zeros_(m.bias)
+                nn.init.constant_(m.bias, 0)

     def forward(self, x):
         # x: [Batch, Seq, Dim]
         batch, seq, dim = x.shape

-        # --- Dual-Path Context Injection ---
-        x_trans = x.transpose(1, 2)  # [B, D, S]
-
-        # Branch A: Convolutional Detail
-        x_conv = self.local_context(x_trans)
+        # Manual Padding (prevents frame dropping)
+        if seq % self.k != 0:
+            pad_len = self.k - (seq % self.k)
+            x = F.pad(x, (0, 0, 0, pad_len))

-        # Branch B: Energy Abstraction
-        x_energy = self.energy_pool(x_trans)
+        # [B, S, D] -> [B, D, S]
+        x = x.transpose(1, 2)

-        # Combine and Add Residual (Summing the branches)
-        # This gives the model a multi-resolution view of the audio
-        x = (x_conv + x_energy).transpose(1, 2) + x
+        # Downsample (50Hz -> 12.5Hz)
+        x = self.downsample(x)

-        # --- Frame Concatenation ---
-        if seq % self.k:
-            x = F.pad(x, (0, 0, 0, self.k - (seq % self.k)))
-        x = x.reshape(batch, -1, dim * self.k)
+        # [B, D, S] -> [B, S, D]
+        x = x.transpose(1, 2)

-        # --- Projection ---
-        x = self.pre_norm(x)
-        x = self.proj1(x)
-        return self.proj2(x)
+        # Norm & Project
+        x = self.norm(x)
+        return self.proj(x)

     def get_output_length(self, input_length: int) -> int:
-        remainder = input_length % self.k
-        return (input_length + self.k - 1) // self.k if remainder else input_length // self.k
-
-
-# Alias for backwards compatibility
-AudioProjector = SwiGLUAudioProjector
-
+        return (input_length + self.k - 1) // self.k

 # =============================================================================
 # Residual Projector
@@ -324,20 +308,17 @@ AudioProjector = SwiGLUAudioProjector
 class ResidualMLP(nn.Module):
     """MLP block with residual connection: Output = x + MLP(x)."""

-    def __init__(self, dim, hidden_dim, dropout=0.0):
+    def __init__(self, dim, hidden_dim):
         super().__init__()
         self.fc1 = nn.Linear(dim, hidden_dim)
         self.fc2 = nn.Linear(hidden_dim, dim)
         self.act = nn.GELU()
-        self.dropout = nn.Dropout(dropout)

     def forward(self, x):
         residual = x
         x = self.fc1(x)
         x = self.act(x)
-        x = self.dropout(x)
         x = self.fc2(x)
-        x = self.dropout(x)
         return residual + x


@@ -352,19 +333,17 @@ class ResidualAudioProjector(nn.Module):
         out_dim = config.llm_dim
         hidden_dim = getattr(config, "projector_hidden_dim", None) or out_dim * 4
         self.num_layers = getattr(config, "projector_num_layers", 2)
-        dropout_rate = getattr(config, "projector_dropout", 0.0)

         self.input_proj = nn.Linear(in_dim, out_dim)
         self.ln_input = LlamaRMSNorm(out_dim, eps=1e-8)

         self.layers = nn.ModuleList(
-            [ResidualMLP(out_dim, hidden_dim, dropout=dropout_rate) for _ in range(self.num_layers)]
+            [ResidualMLP(out_dim, hidden_dim) for _ in range(self.num_layers)]
         )
         self.layer_norms = nn.ModuleList(
             [LlamaRMSNorm(out_dim, eps=1e-8) for _ in range(self.num_layers)]
         )

-        self.output_dropout = nn.Dropout(dropout_rate)
         self._init_weights(config)

     def _init_weights(self, config):
@@ -415,7 +394,7 @@ class ResidualAudioProjector(nn.Module):
             x = layer(x)
             x = ln(x)

-        return self.output_dropout(x)
+        return x


 # =============================================================================
@@ -526,7 +505,6 @@ class SharedMoEAudioProjector(nn.Module):
     def __init__(self, config):
         super().__init__()

-        # Default stride is now 2 (was 4)
         self.k = getattr(config, "projector_pool_stride", 4)
         encoder_dim = config.encoder_dim

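
A quick shape check of the new downsample -> norm -> SwiGLU path (a self-contained sketch: the stride-k Conv1d, 2x width expansion, and 8/3 inner ratio mirror the hunks above, while encoder_dim, llm_dim, and the batch/sequence sizes are hypothetical):

import torch
import torch.nn as nn
import torch.nn.functional as F

encoder_dim, llm_dim, k = 1280, 2048, 4   # hypothetical dims; k = pool stride
hidden_dim = encoder_dim * 2              # 2x width offsets 4x time compression
swiglu_inner = int(hidden_dim * 8 / 3)    # Llama-style 8/3 FFN expansion

downsample = nn.Conv1d(encoder_dim, hidden_dim, kernel_size=k, stride=k)
w1 = nn.Linear(hidden_dim, swiglu_inner, bias=False)  # gate
w2 = nn.Linear(hidden_dim, swiglu_inner, bias=False)  # value
w3 = nn.Linear(swiglu_inner, llm_dim, bias=False)     # output

x = torch.randn(2, 1498, encoder_dim)                 # [B, S, D]; S not divisible by k
if x.shape[1] % k:
    x = F.pad(x, (0, 0, 0, k - x.shape[1] % k))       # pad time so no frames are dropped
x = downsample(x.transpose(1, 2)).transpose(1, 2)     # 50 Hz -> 12.5 Hz frame rate
y = w3(F.silu(w1(x)) * w2(x))                         # SwiGLU: SiLU(gate) * value
print(y.shape)  # torch.Size([2, 375, 2048]); 375 == (1498 + k - 1) // k

At a 50 Hz encoder frame rate this turns ~30 s of audio (1500 frames) into 375 audio tokens for the LLM, matching get_output_length above.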