mazesmazes committed on
Commit
eb665c7
·
verified ·
1 Parent(s): 948c073

Training in progress - step 500

Browse files
Files changed (1) hide show
  1. projectors.py +11 -19
projectors.py CHANGED
@@ -605,9 +605,9 @@ class TransformerAudioProjector(nn.Module):
605
 
606
  def __init__(self, config):
607
  super().__init__()
608
- # Default stride 4: Whisper (2x) * Projector (4x) = 8x total → ~12.5 Hz
609
- # Similar to FunASR's 6x total (~16.67 Hz)
610
- self.k = getattr(config, "projector_pool_stride", 4)
611
 
612
  encoder_dim = config.encoder_dim
613
  llm_dim = config.llm_dim
@@ -615,8 +615,9 @@ class TransformerAudioProjector(nn.Module):
615
  # Input: Stacked frames (e.g. 1280 * 2 = 2560)
616
  in_dim = encoder_dim * self.k
617
 
618
- # FFN hidden dim for initial projection (FunASR default: 2048)
619
- ffn_dim = getattr(config, "projector_hidden_dim", None) or 2048
 
620
 
621
  # FunASR-style projection: linear1 -> relu -> linear2
622
  self.linear1 = nn.Linear(in_dim, ffn_dim)
@@ -629,27 +630,18 @@ class TransformerAudioProjector(nn.Module):
629
  encoder_layer = nn.TransformerEncoderLayer(
630
  d_model=llm_dim,
631
  nhead=getattr(config, "projector_num_heads", 8),
632
- dim_feedforward=llm_dim // 4, # FunASR uses quarter size
633
  dropout=0.0,
634
  activation="relu",
635
  batch_first=True,
636
  norm_first=True,
637
  )
638
- self.blocks = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
 
 
639
  else:
640
  self.blocks = None
641
 
642
- # Final Norm for stability when projecting to frozen LLM
643
- self.norm = LlamaRMSNorm(llm_dim, eps=1e-8)
644
-
645
- self.apply(self._init_weights)
646
-
647
- def _init_weights(self, m):
648
- if isinstance(m, nn.Linear):
649
- nn.init.trunc_normal_(m.weight, std=0.02)
650
- if m.bias is not None:
651
- nn.init.zeros_(m.bias)
652
-
653
  def forward(self, x):
654
  # x: [Batch, Seq, Dim]
655
  batch, seq, dim = x.shape
@@ -672,7 +664,7 @@ class TransformerAudioProjector(nn.Module):
672
  if self.blocks is not None:
673
  x = self.blocks(x)
674
 
675
- return self.norm(x)
676
 
677
  def get_output_length(self, input_length: int) -> int:
678
  return (input_length - 1) // self.k + 1
 
605
 
606
  def __init__(self, config):
607
  super().__init__()
608
+ # Default stride 6: Whisper (2x) * Projector (6x) = 12x total → ~8 Hz
609
+ # Matches FunASR's total stride (6x encoder * 2x projector = 12x)
610
+ self.k = getattr(config, "projector_pool_stride", 6)
611
 
612
  encoder_dim = config.encoder_dim
613
  llm_dim = config.llm_dim
 
615
  # Input: Stacked frames (e.g. 1280 * 2 = 2560)
616
  in_dim = encoder_dim * self.k
617
 
618
+ # FFN hidden dim for initial projection (balanced compression)
619
+ # 7680 → 4096 → 2048 distributes compression evenly (~2x each layer)
620
+ ffn_dim = getattr(config, "projector_hidden_dim", None) or 4096
621
 
622
  # FunASR-style projection: linear1 -> relu -> linear2
623
  self.linear1 = nn.Linear(in_dim, ffn_dim)
 
630
  encoder_layer = nn.TransformerEncoderLayer(
631
  d_model=llm_dim,
632
  nhead=getattr(config, "projector_num_heads", 8),
633
+ dim_feedforward=1024, # Match FunASR (audio complexity is LLM-independent)
634
  dropout=0.0,
635
  activation="relu",
636
  batch_first=True,
637
  norm_first=True,
638
  )
639
+ self.blocks = nn.TransformerEncoder(
640
+ encoder_layer, num_layers=num_layers, enable_nested_tensor=False
641
+ )
642
  else:
643
  self.blocks = None
644
 
 
 
 
 
 
 
 
 
 
 
 
645
  def forward(self, x):
646
  # x: [Batch, Seq, Dim]
647
  batch, seq, dim = x.shape
 
664
  if self.blocks is not None:
665
  x = self.blocks(x)
666
 
667
+ return x
668
 
669
  def get_output_length(self, input_length: int) -> int:
670
  return (input_length - 1) // self.k + 1