Training in progress - step 500
Browse files- projectors.py +11 -19
projectors.py
CHANGED
|
@@ -605,9 +605,9 @@ class TransformerAudioProjector(nn.Module):
|
|
| 605 |
|
| 606 |
def __init__(self, config):
|
| 607 |
super().__init__()
|
| 608 |
-
# Default stride
|
| 609 |
-
#
|
| 610 |
-
self.k = getattr(config, "projector_pool_stride",
|
| 611 |
|
| 612 |
encoder_dim = config.encoder_dim
|
| 613 |
llm_dim = config.llm_dim
|
|
@@ -615,8 +615,9 @@ class TransformerAudioProjector(nn.Module):
|
|
| 615 |
# Input: Stacked frames (e.g. 1280 * 2 = 2560)
|
| 616 |
in_dim = encoder_dim * self.k
|
| 617 |
|
| 618 |
-
# FFN hidden dim for initial projection (
|
| 619 |
-
|
|
|
|
| 620 |
|
| 621 |
# FunASR-style projection: linear1 -> relu -> linear2
|
| 622 |
self.linear1 = nn.Linear(in_dim, ffn_dim)
|
|
@@ -629,27 +630,18 @@ class TransformerAudioProjector(nn.Module):
|
|
| 629 |
encoder_layer = nn.TransformerEncoderLayer(
|
| 630 |
d_model=llm_dim,
|
| 631 |
nhead=getattr(config, "projector_num_heads", 8),
|
| 632 |
-
dim_feedforward=
|
| 633 |
dropout=0.0,
|
| 634 |
activation="relu",
|
| 635 |
batch_first=True,
|
| 636 |
norm_first=True,
|
| 637 |
)
|
| 638 |
-
self.blocks = nn.TransformerEncoder(
|
|
|
|
|
|
|
| 639 |
else:
|
| 640 |
self.blocks = None
|
| 641 |
|
| 642 |
-
# Final Norm for stability when projecting to frozen LLM
|
| 643 |
-
self.norm = LlamaRMSNorm(llm_dim, eps=1e-8)
|
| 644 |
-
|
| 645 |
-
self.apply(self._init_weights)
|
| 646 |
-
|
| 647 |
-
def _init_weights(self, m):
|
| 648 |
-
if isinstance(m, nn.Linear):
|
| 649 |
-
nn.init.trunc_normal_(m.weight, std=0.02)
|
| 650 |
-
if m.bias is not None:
|
| 651 |
-
nn.init.zeros_(m.bias)
|
| 652 |
-
|
| 653 |
def forward(self, x):
|
| 654 |
# x: [Batch, Seq, Dim]
|
| 655 |
batch, seq, dim = x.shape
|
|
@@ -672,7 +664,7 @@ class TransformerAudioProjector(nn.Module):
|
|
| 672 |
if self.blocks is not None:
|
| 673 |
x = self.blocks(x)
|
| 674 |
|
| 675 |
-
return
|
| 676 |
|
| 677 |
def get_output_length(self, input_length: int) -> int:
|
| 678 |
return (input_length - 1) // self.k + 1
|
|
|
|
| 605 |
|
| 606 |
def __init__(self, config):
|
| 607 |
super().__init__()
|
| 608 |
+
# Default stride 6: Whisper (2x) * Projector (6x) = 12x total → ~8 Hz
|
| 609 |
+
# Matches FunASR's total stride (6x encoder * 2x projector = 12x)
|
| 610 |
+
self.k = getattr(config, "projector_pool_stride", 6)
|
| 611 |
|
| 612 |
encoder_dim = config.encoder_dim
|
| 613 |
llm_dim = config.llm_dim
|
|
|
|
| 615 |
# Input: Stacked frames (e.g. 1280 * 6 = 7680 with the default stride of 6)
|
| 616 |
in_dim = encoder_dim * self.k
|
| 617 |
|
| 618 |
+
# FFN hidden dim for initial projection (balanced compression)
|
| 619 |
+
# 7680 → 4096 → 2048 distributes compression evenly (~2x each layer)
|
| 620 |
+
ffn_dim = getattr(config, "projector_hidden_dim", None) or 4096
|
| 621 |
|
| 622 |
# FunASR-style projection: linear1 -> relu -> linear2
|
| 623 |
self.linear1 = nn.Linear(in_dim, ffn_dim)
|
|
|
|
| 630 |
encoder_layer = nn.TransformerEncoderLayer(
|
| 631 |
d_model=llm_dim,
|
| 632 |
nhead=getattr(config, "projector_num_heads", 8),
|
| 633 |
+
dim_feedforward=1024, # Match FunASR (audio complexity is LLM-independent)
|
| 634 |
dropout=0.0,
|
| 635 |
activation="relu",
|
| 636 |
batch_first=True,
|
| 637 |
norm_first=True,
|
| 638 |
)
|
| 639 |
+
self.blocks = nn.TransformerEncoder(
|
| 640 |
+
encoder_layer, num_layers=num_layers, enable_nested_tensor=False
|
| 641 |
+
)
|
| 642 |
else:
|
| 643 |
self.blocks = None
|
| 644 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 645 |
def forward(self, x):
|
| 646 |
# x: [Batch, Seq, Dim]
|
| 647 |
batch, seq, dim = x.shape
|
|
|
|
| 664 |
if self.blocks is not None:
|
| 665 |
x = self.blocks(x)
|
| 666 |
|
| 667 |
+
return x
|
| 668 |
|
| 669 |
def get_output_length(self, input_length: int) -> int:
|
| 670 |
return (input_length - 1) // self.k + 1
|