mazesmazes committed on
Commit
158b30e
·
verified ·
1 Parent(s): 29e583c

Training in progress - step 500

Browse files
Files changed (2) hide show
  1. model.safetensors +2 -2
  2. projectors.py +101 -133
model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:a1fbff9f5ed1bd4099f58cd2e093052409854375c4cde75c051e28ad58ae1122
3
- size 124082792
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:720402f0e8107015c77907789cc8b20307741b0008412ad5e2cffc08462ae5c9
3
+ size 265642600
projectors.py CHANGED
@@ -237,8 +237,8 @@ class SwiGLU(nn.Module):
237
 
238
  class SwiGLUAudioProjector(nn.Module):
239
  """
240
- Optimized for Frozen LLM + 2500h Data.
241
- Target: 12.5 Hz Output (Stride 4) with 8/3 SwiGLU Expansion.
242
  """
243
 
244
  def __init__(self, config):
@@ -247,154 +247,44 @@ class SwiGLUAudioProjector(nn.Module):
247
  encoder_dim = config.encoder_dim
248
  llm_dim = config.llm_dim
249
 
250
- # Conv Expansion (Compensating for Time Compression)
251
- # We compress time by 4x, so we expand width by 2x to preserve info density.
252
- hidden_dim = int(encoder_dim * 2)
253
 
254
- # SwiGLU Internal Expansion (The 8/3 Ratio)
255
- # To match standard FFN capacity: 4 * (2/3) = 8/3
256
- swiglu_inner = int(hidden_dim * 8 / 3)
257
 
258
- self.downsample = nn.Conv1d(
259
- in_channels=encoder_dim,
260
- out_channels=hidden_dim,
261
- kernel_size=self.k,
262
- stride=self.k,
263
- padding=0,
264
- )
265
 
 
266
  self.norm = LlamaRMSNorm(hidden_dim, eps=1e-8)
267
 
 
 
268
  self.proj = SwiGLU(hidden_dim, swiglu_inner, llm_dim)
269
 
270
- self.apply(self._init_weights)
271
-
272
- def _init_weights(self, m):
273
- if isinstance(m, (nn.Linear, nn.Conv1d)):
274
- nn.init.trunc_normal_(m.weight, std=0.02)
275
- if m.bias is not None:
276
- nn.init.constant_(m.bias, 0)
277
-
278
  def forward(self, x):
279
  # x: [Batch, Seq, Dim]
280
  batch, seq, dim = x.shape
281
 
282
- # Manual Padding (prevents frame dropping)
283
- if seq % self.k != 0:
284
- pad_len = self.k - (seq % self.k)
285
- x = F.pad(x, (0, 0, 0, pad_len))
 
286
 
287
- # [B, S, D] -> [B, D, S]
288
- x = x.transpose(1, 2)
289
 
290
- # Downsample (50Hz -> 12.5Hz)
291
- x = self.downsample(x)
292
-
293
- # [B, D, S] -> [B, S, D]
294
- x = x.transpose(1, 2)
295
 
296
- # Norm & Project
297
  x = self.norm(x)
298
  return self.proj(x)
299
 
300
  def get_output_length(self, input_length: int) -> int:
301
- return (input_length + self.k - 1) // self.k
302
-
303
- # =============================================================================
304
- # Residual Projector
305
- # =============================================================================
306
-
307
-
308
class ResidualMLP(nn.Module):
    """Two-layer GELU MLP wrapped in a skip connection.

    Computes ``x + fc2(gelu(fc1(x)))`` so the block can start near-identity
    and learn a correction on top of its input.
    """

    def __init__(self, dim, hidden_dim):
        super().__init__()
        # Expand to hidden_dim, then contract back to dim so the skip adds cleanly.
        self.fc1 = nn.Linear(dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, dim)
        self.act = nn.GELU()

    def forward(self, x):
        """Apply the MLP branch and add the input back (residual connection)."""
        return x + self.fc2(self.act(self.fc1(x)))
323
-
324
-
325
class ResidualAudioProjector(nn.Module):
    """Frame-stacking projector built from residual MLP blocks.

    Pipeline: right-pad the sequence to a multiple of k, stack k consecutive
    frames into one wide vector, linearly project to the LLM width,
    RMS-normalize, then run ``projector_num_layers`` ResidualMLP blocks, each
    followed by its own RMSNorm.
    """

    def __init__(self, config):
        super().__init__()

        # Temporal stacking factor: k encoder frames become one output token.
        self.k = getattr(config, "projector_pool_stride", 4)
        stacked_dim = config.encoder_dim * self.k
        llm_dim = config.llm_dim
        mlp_hidden = getattr(config, "projector_hidden_dim", None) or llm_dim * 4
        self.num_layers = getattr(config, "projector_num_layers", 2)

        self.input_proj = nn.Linear(stacked_dim, llm_dim)
        self.ln_input = LlamaRMSNorm(llm_dim, eps=1e-8)

        self.layers = nn.ModuleList(
            [ResidualMLP(llm_dim, mlp_hidden) for _ in range(self.num_layers)]
        )
        self.layer_norms = nn.ModuleList(
            [LlamaRMSNorm(llm_dim, eps=1e-8) for _ in range(self.num_layers)]
        )

        self._init_weights(config)

    def _init_weights(self, config):
        """Gaussian init for linear layers; fc2 gets a 10x smaller std so each
        residual branch starts close to identity. Norm weights reset to 1."""
        std = getattr(config, "projector_init_std", 0.02)

        with torch.no_grad():
            nn.init.normal_(self.input_proj.weight, mean=0.0, std=std)
            if self.input_proj.bias is not None:
                nn.init.zeros_(self.input_proj.bias)

            self.ln_input.weight.data.fill_(1.0)
            for norm in self.layer_norms:
                norm.weight.data.fill_(1.0)

            for block in self.layers:
                nn.init.normal_(block.fc1.weight, mean=0.0, std=std)
                nn.init.normal_(block.fc2.weight, mean=0.0, std=std * 0.1)
                if block.fc1.bias is not None:
                    nn.init.zeros_(block.fc1.bias)
                if block.fc2.bias is not None:
                    nn.init.zeros_(block.fc2.bias)

    def get_output_length(self, input_length: int) -> int:
        """Output token count for a given number of encoder frames (ceil by k)."""
        leftover = input_length % self.k
        if leftover:
            input_length += self.k - leftover
        return input_length // self.k

    def forward(self, x):
        """Project [B, S, D] encoder features to [B, ceil(S/k), llm_dim]."""
        batch_size, seq_len, feat_dim = x.size()

        # Match the parameter dtype up front (the encoder may emit another one).
        weight_dtype = self.input_proj.weight.dtype
        if x.dtype != weight_dtype:
            x = x.to(weight_dtype)

        # Zero-pad the tail so frame stacking consumes every input frame.
        leftover = seq_len % self.k
        if leftover:
            x = F.pad(x, (0, 0, 0, self.k - leftover))

        # [B, S, D] -> [B, S/k, D*k]: concatenate k neighboring frames.
        x = x.contiguous().view(batch_size, -1, feat_dim * self.k)
        x = self.ln_input(self.input_proj(x))

        for block, norm in zip(self.layers, self.layer_norms):
            x = norm(block(x))

        return x
398
 
399
 
400
  # =============================================================================
@@ -688,6 +578,84 @@ class QFormerAudioProjector(nn.Module):
688
  return self.linear(query_proj)
689
 
690
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
691
  # =============================================================================
692
  # Projector Registry
693
  # =============================================================================
@@ -696,7 +664,7 @@ PROJECTOR_CLASSES = {
696
  "mlp": MLPAudioProjector,
697
  "mosa": MOSAProjector,
698
  "swiglu": SwiGLUAudioProjector,
699
- "residual": ResidualAudioProjector,
700
  "shared_moe": SharedMoEAudioProjector,
701
  "qformer": QFormerAudioProjector,
 
702
  }
 
237
 
238
  class SwiGLUAudioProjector(nn.Module):
239
  """
240
+ SwiGLU projector with frame stacking (FunASR-style).
241
+ Uses frame stacking for downsampling, linear projection, then SwiGLU.
242
  """
243
 
244
  def __init__(self, config):
 
247
  encoder_dim = config.encoder_dim
248
  llm_dim = config.llm_dim
249
 
250
+ # Frame stacking input dimension
251
+ in_dim = encoder_dim * self.k # 1280 * 4 = 5120
 
252
 
253
+ # Hidden dim after initial projection (balanced compression like transformer)
254
+ hidden_dim = getattr(config, "projector_hidden_dim", None) or 4096
 
255
 
256
+ # Initial linear projection (frame stacking → hidden)
257
+ self.linear = nn.Linear(in_dim, hidden_dim)
 
 
 
 
 
258
 
259
+ # Norm before SwiGLU
260
  self.norm = LlamaRMSNorm(hidden_dim, eps=1e-8)
261
 
262
+ # SwiGLU with 8/3 expansion ratio
263
+ swiglu_inner = int(hidden_dim * 8 / 3)
264
  self.proj = SwiGLU(hidden_dim, swiglu_inner, llm_dim)
265
 
 
 
 
 
 
 
 
 
266
  def forward(self, x):
267
  # x: [Batch, Seq, Dim]
268
  batch, seq, dim = x.shape
269
 
270
+ # Padding to multiple of k
271
+ chunk_num = (seq - 1) // self.k + 1
272
+ pad_num = chunk_num * self.k - seq
273
+ if pad_num > 0:
274
+ x = F.pad(x, (0, 0, 0, pad_num))
275
 
276
+ # Frame stacking: [B, S, D] -> [B, S/k, D*k]
277
+ x = x.contiguous().view(batch, chunk_num, dim * self.k)
278
 
279
+ # Linear projection
280
+ x = self.linear(x)
 
 
 
281
 
282
+ # Norm & SwiGLU
283
  x = self.norm(x)
284
  return self.proj(x)
285
 
286
  def get_output_length(self, input_length: int) -> int:
287
+ return (input_length - 1) // self.k + 1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
288
 
289
 
290
  # =============================================================================
 
578
  return self.linear(query_proj)
579
 
580
 
581
+ # =============================================================================
582
+ # Transformer Projector
583
+ # =============================================================================
584
+
585
+
586
class TransformerAudioProjector(nn.Module):
    """
    Transformer projector (FunASR style).

    Stacks k encoder frames per output token, projects the stacked vector to
    the LLM width with a two-layer ReLU MLP, then (optionally) mixes context
    with a small pre-norm Transformer encoder running at llm_dim.
    """

    def __init__(self, config):
        super().__init__()
        # Temporal stacking factor; default 6 encoder frames per output token.
        self.k = getattr(config, "projector_pool_stride", 6)

        encoder_dim = config.encoder_dim
        llm_dim = config.llm_dim

        # Width of one stacked token: k consecutive encoder frames side by side.
        stacked_dim = encoder_dim * self.k

        # Intermediate width of the projection MLP.
        hidden_dim = getattr(config, "projector_hidden_dim", None) or 4096

        # FunASR-style projection: linear -> ReLU -> linear.
        self.linear1 = nn.Linear(stacked_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(hidden_dim, llm_dim)

        # Optional context-mixing Transformer operating at llm_dim.
        num_layers = getattr(config, "projector_num_layers", 2)
        if num_layers > 0:
            layer = nn.TransformerEncoderLayer(
                d_model=llm_dim,
                nhead=getattr(config, "projector_num_heads", 8),
                dim_feedforward=1024,  # small fixed FFN, independent of LLM size
                dropout=0.0,
                activation="relu",
                batch_first=True,
                norm_first=True,
            )
            self.blocks = nn.TransformerEncoder(
                layer, num_layers=num_layers, enable_nested_tensor=False
            )
        else:
            # num_layers == 0 disables context mixing entirely.
            self.blocks = None

    def forward(self, x):
        """Map [B, S, D] encoder output to [B, ceil(S/k), llm_dim]."""
        batch_size, seq_len, feat_dim = x.shape

        # Zero-pad the tail so the sequence splits evenly into k-frame chunks.
        n_chunks = (seq_len - 1) // self.k + 1
        tail = n_chunks * self.k - seq_len
        if tail > 0:
            x = F.pad(x, (0, 0, 0, tail))

        # Frame stacking: [B, S, D] -> [B, S/k, D*k].
        x = x.contiguous().view(batch_size, n_chunks, feat_dim * self.k)

        # Two-layer projection down to the LLM embedding width.
        x = self.linear2(self.relu(self.linear1(x)))

        # Optional self-attention over the stacked tokens.
        if self.blocks is not None:
            x = self.blocks(x)

        return x

    def get_output_length(self, input_length: int) -> int:
        """Output token count: ceil(input_length / k)."""
        return (input_length - 1) // self.k + 1
657
+
658
+
659
  # =============================================================================
660
  # Projector Registry
661
  # =============================================================================
 
664
  "mlp": MLPAudioProjector,
665
  "mosa": MOSAProjector,
666
  "swiglu": SwiGLUAudioProjector,
 
667
  "shared_moe": SharedMoEAudioProjector,
668
  "qformer": QFormerAudioProjector,
669
+ "transformer": TransformerAudioProjector,
670
  }