Training in progress - step 500
Browse files- projectors.py +5 -3
projectors.py
CHANGED
|
@@ -21,7 +21,7 @@ from transformers.models.llama.modeling_llama import LlamaRMSNorm
|
|
| 21 |
|
| 22 |
|
| 23 |
class MLPAudioProjector(nn.Module):
|
| 24 |
-
"""2-layer MLP projector with frame-stacking downsampling (
|
| 25 |
|
| 26 |
def __init__(self, config):
|
| 27 |
super().__init__()
|
|
@@ -31,10 +31,12 @@ class MLPAudioProjector(nn.Module):
|
|
| 31 |
self.k = getattr(config, "projector_pool_stride", 4)
|
| 32 |
|
| 33 |
# Frame stacking: concat k adjacent frames then project
|
|
|
|
| 34 |
in_dim = encoder_dim * self.k
|
| 35 |
-
|
|
|
|
| 36 |
self.act = nn.GELU()
|
| 37 |
-
self.linear_2 = nn.Linear(
|
| 38 |
|
| 39 |
def get_output_length(self, input_length: int) -> int:
|
| 40 |
"""Calculate output sequence length given input length."""
|
|
|
|
| 21 |
|
| 22 |
|
| 23 |
class MLPAudioProjector(nn.Module):
|
| 24 |
+
"""2-layer MLP projector with frame-stacking downsampling (matches GLM-ASR)."""
|
| 25 |
|
| 26 |
def __init__(self, config):
|
| 27 |
super().__init__()
|
|
|
|
| 31 |
self.k = getattr(config, "projector_pool_stride", 4)
|
| 32 |
|
| 33 |
# Frame stacking: concat k adjacent frames then project
|
| 34 |
+
# Matches GLM-ASR: in_dim -> 2*llm_dim -> llm_dim
|
| 35 |
in_dim = encoder_dim * self.k
|
| 36 |
+
hidden_dim = llm_dim * 2
|
| 37 |
+
self.linear_1 = nn.Linear(in_dim, hidden_dim)
|
| 38 |
self.act = nn.GELU()
|
| 39 |
+
self.linear_2 = nn.Linear(hidden_dim, llm_dim)
|
| 40 |
|
| 41 |
def get_output_length(self, input_length: int) -> int:
|
| 42 |
"""Calculate output sequence length given input length."""
|