mazesmazes committed
Commit 8329143 · verified · 1 Parent(s): 0366e08

Training in progress - step 1000

Files changed (1): projectors.py (+37 -84)
projectors.py CHANGED
@@ -1,10 +1,9 @@
  """Audio projector modules for bridging encoder and decoder embeddings.
 
  This module contains all projector architectures:
- - LinearProjector: Simple avg pool + linear (Chinese Dialects paper, best for Stage 1)
- - MLPAudioProjector: 2-layer MLP with frame stacking downsampling
- - MOSAProjector: MOSA-style dense mixture of experts (arXiv:2508.18998)
- - MoEAudioProjector: Shared expert + sparse routed experts
+ - MLPAudioProjector: Simple 2-layer MLP with frame stacking downsampling
+ - MOSAProjector: MOSA-style dense mixture of experts
+ - SharedMoEAudioProjector: Shared expert + sparse routed experts
  - QFormerAudioProjector: BLIP-2 QFormer with learnable queries (Granite-style)
  """
 
@@ -16,51 +15,6 @@ import torch.nn.functional as F  # noqa: N812
  from transformers import AutoModel, Blip2QFormerConfig
  from transformers.models.llama.modeling_llama import LlamaRMSNorm
 
- # =============================================================================
- # Linear Projector (Chinese Dialects paper style)
- # =============================================================================
-
-
- class LinearProjector(nn.Module):
-     """Simple linear projector with average pooling downsampling.
-
-     Based on Chinese Dialects paper (arXiv:2505.21138) which found this
-     outperformed Conv1D, Transformer, and Q-Former in Stage 1 (projector-only).
-
-     Architecture: AvgPool(4x) -> Linear(encoder_dim, llm_dim)
-     """
-
-     def __init__(self, config):
-         super().__init__()
-
-         self.encoder_dim = getattr(config, "encoder_dim", 768)
-         self.llm_dim = getattr(config, "llm_dim", 2048)
-         self.pool_stride = getattr(config, "projector_pool_stride", 4)  # 4x = 12.5Hz
-
-         # Single linear projection (no hidden layers, no activation)
-         self.linear = nn.Linear(self.encoder_dim, self.llm_dim)
-
-     def get_output_length(self, input_length: int) -> int:
-         """Calculate output sequence length given input length."""
-         return input_length // self.pool_stride
-
-     def forward(self, x: torch.Tensor) -> torch.Tensor:
-         """
-         Args:
-             x: [batch, seq_len, encoder_dim]
-
-         Returns:
-             [batch, seq_len // pool_stride, llm_dim]
-         """
-         # Average pooling for downsampling (better than frame stacking for linear)
-         # Transpose for avg_pool1d: [B, S, D] -> [B, D, S]
-         x = x.transpose(1, 2)
-         x = F.avg_pool1d(x, kernel_size=self.pool_stride, stride=self.pool_stride)
-         x = x.transpose(1, 2)  # [B, S//k, D]
-
-         return self.linear(x)
-
-
  # =============================================================================
  # MLP Projector
  # =============================================================================
@@ -85,18 +39,21 @@ class MLPAudioProjector(nn.Module):
          self.linear_2 = nn.Linear(hidden_dim, llm_dim)
 
      def get_output_length(self, input_length: int) -> int:
-         """Calculate output sequence length given input length."""
-         return input_length // self.k
+         """Calculate output sequence length given input length (matches GLM-ASR)."""
+         # GLM-ASR formula: (L - merge_factor) // merge_factor + 1
+         return (input_length - self.k) // self.k + 1
 
      def forward(self, x):
          """
          x: [Batch, Seq_Len, Dim]
-         Returns: [Batch, Seq_Len // k, llm_dim]
+         Returns: [Batch, (Seq_Len - k) // k + 1, llm_dim]
          """
          batch, seq, dim = x.shape
-         # Reshape to combine k frames: [B, S, D] -> [B, -1, D*k]
-         # -1 infers sequence length, implicitly downsampling by factor k
-         x = x.reshape(batch, -1, dim * self.k)
+         # Truncate to match GLM-ASR: use (seq - k) // k + 1 frames
+         # This drops trailing frames that don't fill a complete k-frame window
+         out_len = (seq - self.k) // self.k + 1
+         x = x[:, : out_len * self.k, :]  # Truncate to exact multiple
+         x = x.reshape(batch, out_len, dim * self.k)
 
          x = self.linear_1(x)
          x = self.act(x)
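Note: a quick standalone check of the new length formula (assuming k = 4, the default merge factor). For inputs of at least k frames, (L - k) // k + 1 reduces to plain floor division, and the truncation in forward() never reads past the input:

    k = 4
    for seq in range(k, 40):
        out_len = (seq - k) // k + 1  # GLM-ASR merge formula
        assert out_len == seq // k    # equivalent to floor division for seq >= k
        assert out_len * k <= seq     # truncated slice stays in bounds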
@@ -109,12 +66,12 @@
 
 
  class SimpleAdapter(nn.Module):
-     """Simple 2-layer ReLU adapter (from MOSA paper, arXiv:2508.18998)."""
+     """Simple 2-layer GELU adapter (from MOSA paper)."""
 
      def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
          super().__init__()
          self.fc1 = nn.Linear(input_dim, hidden_dim)
-         self.act = nn.ReLU()
+         self.act = nn.GELU()
          self.fc2 = nn.Linear(hidden_dim, output_dim)
 
      def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -139,27 +96,19 @@ class MOSAProjector(nn.Module):
 
      Based on "MOSA: Mixtures of Simple Adapters" (arXiv:2508.18998).
      Uses softmax gating over all experts (dense MoE) with only cross-entropy loss.
-     Uses conv-based downsampling (2x Conv1d stride-2) as described in the paper.
+     Uses frame-stacking for downsampling (like MLP projector).
      """
 
      def __init__(self, config):
          super().__init__()
          self.encoder_dim = getattr(config, "encoder_dim", None) or 1280
          self.llm_dim = getattr(config, "llm_dim", None) or 2048
+         self.k = getattr(config, "projector_pool_stride", 4)
          self.num_experts = getattr(config, "num_experts", None) or 4  # MOSA-Base uses 4
          adapter_hidden = getattr(config, "adapter_hidden_dim", None) or 4096
 
-         # --- Conv-based downsampling (paper: 2x Conv1d, kernel=3, stride=2) ---
-         # Total 4x downsampling: 50Hz -> 12.5Hz
-         self.conv_downsample = nn.Sequential(
-             nn.Conv1d(self.encoder_dim, self.encoder_dim, kernel_size=3, stride=2, padding=1),
-             nn.ReLU(),
-             nn.Conv1d(self.encoder_dim, self.encoder_dim, kernel_size=3, stride=2, padding=1),
-             nn.ReLU(),
-         )
-
-         # Input dim to adapters is now just encoder_dim (not encoder_dim * k)
-         in_dim = self.encoder_dim
+         # Frame stacking: concat k adjacent frames then project
+         in_dim = self.encoder_dim * self.k
 
          # --- 1. Simple Router (MOSA-Base: 2 layers with ReLU) ---
          # Maps encoder_dim -> 512 -> num_experts
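Note: a dimension check under the defaults above (encoder_dim = 1280, k = 4). Each expert now takes four stacked frames, while the router still sees single-frame-width inputs:

    encoder_dim, k = 1280, 4
    in_dim = encoder_dim * k   # 5120-dim expert input (4 stacked 1280-dim frames)
    router_in = encoder_dim    # 1280-dim router input (mean-pooled, see forward below)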
@@ -171,7 +120,7 @@
          )
 
          # --- 2. Experts (Simple 2-layer GELU adapters) ---
-         # Each expert: encoder_dim -> hidden -> llm_dim
+         # Each expert: in_dim (stacked frames) -> hidden -> llm_dim
          self.experts = nn.ModuleList(
              [SimpleAdapter(in_dim, adapter_hidden, self.llm_dim) for _ in range(self.num_experts)]
          )
@@ -180,27 +129,32 @@
          # x: (B, S, encoder_dim)
          batch_size, seq_len, dim = x.shape
 
-         # --- 1. Conv downsampling ---
-         # Conv1d expects (B, C, S), so transpose
-         x_conv = x.transpose(1, 2)  # (B, D, S)
-         x_conv = self.conv_downsample(x_conv)  # (B, D, S//4)
-         x_downsampled = x_conv.transpose(1, 2)  # (B, S//4, D)
+         # Truncate to match GLM-ASR: use (seq - k) // k + 1 frames
+         out_len = (seq_len - self.k) // self.k + 1
+         x = x[:, : out_len * self.k, :]
+
+         # --- 1. Router Branch ---
+         # Mean pool encoder outputs for routing decisions
+         x_pooled = x.reshape(batch_size, out_len, self.k, self.encoder_dim).mean(dim=2)  # (B, out_len, D)
 
-         # --- 2. Router Branch ---
-         # Router operates on downsampled features
-         routing_weights = F.softmax(self.router(x_downsampled), dim=-1)  # (B, S//4, num_experts)
+         # Router logits and softmax gating (dense MoE)
+         routing_weights = F.softmax(self.router(x_pooled), dim=-1)  # (B, out_len, num_experts)
+
+         # --- 2. Frame stacking for experts ---
+         # Reshape to combine k frames: [B, S, D] -> [B, out_len, D*k]
+         x_stacked = x.reshape(batch_size, out_len, dim * self.k)
 
          # --- 3. Expert Mixture (Dense Execution) ---
          # Run all experts and compute weighted sum
          expert_outputs = torch.stack(
-             [expert(x_downsampled) for expert in self.experts]
-         )  # (E, B, S//4, D)
+             [expert(x_stacked) for expert in self.experts]
+         )  # (E, B, out_len, D)
          return torch.einsum("ebsd, bse -> bsd", expert_outputs, routing_weights)
 
      def get_output_length(self, input_length: int) -> int:
-         """Calculate output sequence length given input length."""
-         # Two stride-2 convs = 4x downsampling
-         return input_length // 4
+         """Calculate output sequence length given input length (matches GLM-ASR)."""
+         # GLM-ASR formula: (L - merge_factor) // merge_factor + 1
+         return (input_length - self.k) // self.k + 1
 
 
  # =============================================================================
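Note: a standalone shape sketch of the new forward pass, with random tensors standing in for the real router and adapters (hypothetical sizes; B=2, S=18 input frames):

    import torch
    import torch.nn.functional as F

    B, S, D, k, E, H = 2, 18, 1280, 4, 4, 2048
    x = torch.randn(B, S, D)
    out_len = (S - k) // k + 1                       # GLM-ASR merge formula -> 4
    x = x[:, : out_len * k, :]                       # drop the 2 trailing frames
    x_pooled = x.reshape(B, out_len, k, D).mean(2)   # (B, out_len, D) for the router
    x_stacked = x.reshape(B, out_len, D * k)         # (B, out_len, D*k) for the experts
    routing_weights = F.softmax(torch.randn(B, out_len, E), dim=-1)
    expert_outputs = torch.randn(E, B, out_len, H)   # stand-in for the E adapter outputs
    mixed = torch.einsum("ebsd, bse -> bsd", expert_outputs, routing_weights)
    assert mixed.shape == (B, out_len, H)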
@@ -497,7 +451,6 @@ class QFormerAudioProjector(nn.Module):
  # =============================================================================
 
  PROJECTOR_CLASSES = {
-     "linear": LinearProjector,
      "mlp": MLPAudioProjector,
      "mosa": MOSAProjector,
      "moe": MoEAudioProjector,
 