mazesmazes committed on
Commit
78f3906
·
verified ·
1 Parent(s): 83fef99

Training in progress - step 500

Browse files
Files changed (1) hide show
  1. projectors.py +5 -3
projectors.py CHANGED
@@ -21,7 +21,7 @@ from transformers.models.llama.modeling_llama import LlamaRMSNorm
21
 
22
 
23
  class MLPAudioProjector(nn.Module):
24
- """2-layer MLP projector with frame-stacking downsampling (like GLM-ASR)."""
25
 
26
  def __init__(self, config):
27
  super().__init__()
@@ -31,10 +31,12 @@ class MLPAudioProjector(nn.Module):
31
  self.k = getattr(config, "projector_pool_stride", 4)
32
 
33
  # Frame stacking: concat k adjacent frames then project
 
34
  in_dim = encoder_dim * self.k
35
- self.linear_1 = nn.Linear(in_dim, llm_dim, bias=False)
 
36
  self.act = nn.GELU()
37
- self.linear_2 = nn.Linear(llm_dim, llm_dim, bias=False)
38
 
39
  def get_output_length(self, input_length: int) -> int:
40
  """Calculate output sequence length given input length."""
 
21
 
22
 
23
  class MLPAudioProjector(nn.Module):
24
+ """2-layer MLP projector with frame-stacking downsampling (matches GLM-ASR)."""
25
 
26
  def __init__(self, config):
27
  super().__init__()
 
31
  self.k = getattr(config, "projector_pool_stride", 4)
32
 
33
  # Frame stacking: concat k adjacent frames then project
34
+ # Matches GLM-ASR: in_dim -> 2*llm_dim -> llm_dim
35
  in_dim = encoder_dim * self.k
36
+ hidden_dim = llm_dim * 2
37
+ self.linear_1 = nn.Linear(in_dim, hidden_dim)
38
  self.act = nn.GELU()
39
+ self.linear_2 = nn.Linear(hidden_dim, llm_dim)
40
 
41
  def get_output_length(self, input_length: int) -> int:
42
  """Calculate output sequence length given input length."""