mazesmazes committed on
Commit cae87d2 · verified · 1 Parent(s): dbdcadd

Training in progress - step 500

Files changed (2):
  1. asr_pipeline.py  +25 -0
  2. projectors.py    +46 -104
asr_pipeline.py CHANGED
@@ -476,10 +476,35 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
         text = self.tokenizer.decode(tokens, skip_special_tokens=True).strip()
         # Strip <think>...</think> tags (Qwen3 doesn't respect /no_think prompt)
         text = re.sub(r"<think>.*?</think>\s*", "", text, flags=re.DOTALL).strip()
+        # Collapse spaced-out acronyms (e.g., "I S D S" -> "ISDS")
+        text = self._collapse_acronyms(text)
         # Truncate if a word repeats more than 3 times consecutively
         text = self._truncate_repetitions(text, max_repeats=3)
         return {"text": text}
 
+    def _collapse_acronyms(self, text: str) -> str:
+        """Collapse spaced-out acronyms into single words.
+
+        Converts patterns like "I S D S" to "ISDS" when 2+ single letters
+        are separated by spaces.
+
+        Args:
+            text: Input text with potential spaced acronyms
+
+        Returns:
+            Text with acronyms collapsed
+        """
+        # Match 2+ single letters (case-insensitive) separated by spaces
+        # Pattern: single letter, then one or more (space + single letter)
+        pattern = r"\b([A-Za-z])((?:\s[A-Za-z]){1,})\b"
+
+        def collapse_match(match: re.Match) -> str:
+            # Get the full match and remove spaces
+            full = match.group(0)
+            return full.replace(" ", "").upper()
+
+        return re.sub(pattern, collapse_match, text)
+
     def _truncate_repetitions(self, text: str, max_repeats: int = 3) -> str:
         """Truncate text when a word repeats more than max_repeats times consecutively.
 
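For review purposes, here is a quick standalone check of the new acronym pattern. The pattern is copied verbatim from the diff; the sample strings and the collapse helper name are illustrative, not from the repo.

import re

pattern = r"\b([A-Za-z])((?:\s[A-Za-z]){1,})\b"

def collapse(text: str) -> str:
    # Same replacement as _collapse_acronyms: drop spaces, uppercase
    return re.sub(pattern, lambda m: m.group(0).replace(" ", "").upper(), text)

print(collapse("the I S D S clause"))  # -> "the ISDS clause"
print(collapse("u s a"))               # -> "USA" (lowercase letters are uppercased)
print(collapse("plan a b"))            # -> "plan AB": runs of real one-letter words collapse too

Note the last case: since \s also matches newlines, adjacent one-letter words anywhere in the decoded text will be merged and uppercased, which is presumably an acceptable trade-off for ASR output.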
projectors.py CHANGED
@@ -89,124 +89,68 @@ class SwiGLUExpert(nn.Module):
 
 
 class MOSAProjector(nn.Module):
+    """MOSA-Base projector: simple 2-layer router with 4 simple adapters.
+
+    Based on "MOSA: Mixtures of Simple Adapters" (arXiv:2508.18998).
+    Uses softmax gating over all experts (dense MoE) with only cross-entropy loss.
+    Uses frame-stacking for downsampling (like MLP projector).
+    """
+
     def __init__(self, config):
         super().__init__()
         self.encoder_dim = getattr(config, "encoder_dim", None) or 1280
         self.llm_dim = getattr(config, "llm_dim", None) or 2048
-        self.num_experts = getattr(config, "num_experts", None) or 8
+        self.k = getattr(config, "projector_pool_stride", 4)
+        self.num_experts = getattr(config, "num_experts", None) or 4  # MOSA-Base uses 4
         adapter_hidden = getattr(config, "adapter_hidden_dim", None) or 4096
 
-        # Auxiliary loss coefficients (MOSA paper uses only cross-entropy, no aux losses)
-        self.aux_loss_coef = getattr(config, "router_aux_loss_coef", 0.0)
-        self.z_loss_coef = getattr(config, "router_z_loss_coef", 0.0)
-
-        # Store router state for aux loss computation
-        self.last_router_logits = None
-        self.last_routing_weights = None
-
-        # --- 1. Pre-Norms (CRITICAL for stability) ---
-        self.in_norm = LlamaRMSNorm(self.encoder_dim, eps=1e-8)
-
-        # --- 2. Convolutional Subsampling (Stride 4) ---
-        self.conv = nn.Sequential(
-            nn.Conv1d(self.encoder_dim, self.llm_dim, kernel_size=3, stride=2, padding=1),
-            nn.SiLU(),
-            nn.Conv1d(self.llm_dim, self.llm_dim, kernel_size=3, stride=2, padding=1),
-            nn.SiLU(),
-        )
+        # Frame stacking: concat k adjacent frames then project
+        in_dim = self.encoder_dim * self.k
 
-        # --- 3. Deep Router (ReLU per MOSA paper) ---
+        # --- 1. Simple Router (MOSA-Base: 2 layers with ReLU) ---
+        # Maps encoder_dim -> 512 -> num_experts
+        router_hidden = getattr(config, "router_hidden_dim", None) or 512
         self.router = nn.Sequential(
-            nn.Linear(self.encoder_dim, 2560),
-            nn.ReLU(),
-            nn.Linear(2560, 5120),
-            nn.ReLU(),
-            nn.Linear(5120, 2560),
+            nn.Linear(self.encoder_dim, router_hidden),
             nn.ReLU(),
-            nn.Linear(2560, 1280),
-            nn.ReLU(),
-            nn.Linear(1280, self.num_experts),
+            nn.Linear(router_hidden, self.num_experts),
         )
 
-        # --- 4. Experts (Simple 2-layer ReLU adapters per MOSA paper) ---
+        # --- 2. Experts (Simple 2-layer ReLU adapters per MOSA paper) ---
+        # Each expert: in_dim (stacked frames) -> hidden -> llm_dim
         self.experts = nn.ModuleList(
-            [
-                SimpleAdapter(self.llm_dim, adapter_hidden, self.llm_dim)
-                for _ in range(self.num_experts)
-            ]
+            [SimpleAdapter(in_dim, adapter_hidden, self.llm_dim) for _ in range(self.num_experts)]
        )
 
-        # --- 5. Output Norm ---
-        # Projects often drift in magnitude; this clamps them before the LLM.
-        self.out_norm = LlamaRMSNorm(self.llm_dim, eps=1e-8)
-
-        # Using PyTorch default initialization (like MOSA paper)
-
     def forward(self, x):
-        # x: (B, S, 1280)
-        batch_size, seq_len, _ = x.shape
+        # x: (B, S, encoder_dim)
+        batch_size, seq_len, dim = x.shape
 
-        # Apply Input Norm
-        x = self.in_norm(x)
+        # --- 1. Router Branch ---
+        # Mean pool encoder outputs for routing decisions
+        x_pooled = x.reshape(batch_size, -1, self.k, self.encoder_dim).mean(dim=2)  # (B, S//k, D)
 
-        # --- 1. Conv Branch ---
-        x_trans = x.permute(0, 2, 1)  # (B, D, S)
-        h_conv = self.conv(x_trans).permute(0, 2, 1)  # (B, S//4, llm_dim)
+        # Router logits and softmax gating (dense MoE)
+        routing_weights = F.softmax(self.router(x_pooled), dim=-1)  # (B, S//k, num_experts)
 
-        # --- 2. Router Branch ---
-        pad_amt = (4 - (seq_len % 4)) % 4
-        x_padded = F.pad(x, (0, 0, 0, pad_amt)) if pad_amt > 0 else x
-
-        # Mean pool to align receptive fields
-        x_pooled = x_padded.view(batch_size, -1, 4, self.encoder_dim).mean(dim=2)  # (B, S//4, D)
-
-        # Router Logits
-        router_logits = self.router(x_pooled)  # (B, S//4, num_experts)
-
-        # Softmax for Dense MoE (Soft Mixing)
-        routing_weights = F.softmax(router_logits, dim=-1)
-
-        # Store for aux loss computation
-        self.last_router_logits = router_logits
-        self.last_routing_weights = routing_weights
+        # --- 2. Frame stacking for experts ---
+        # Reshape to combine k frames: [B, S, D] -> [B, S//k, D*k]
+        x_stacked = x.reshape(batch_size, -1, dim * self.k)
 
         # --- 3. Expert Mixture (Dense Execution) ---
-        # Warning: High VRAM usage. Runs all experts.
-        # h_conv: (B, S//4, llm_dim)
-
-        # Stack approach is clean but memory hungry.
-        # Checkpointing could be added here if OOM occurs.
-        expert_outputs = torch.stack([expert(h_conv) for expert in self.experts])  # (E, B, S//4, D)
-
-        # Weighted Sum
-        # (Experts, Batch, Seq, Dim) * (Batch, Seq, Experts) -> (Batch, Seq, Dim)
-        final_out = torch.einsum("ebsd, bse -> bsd", expert_outputs, routing_weights)
-
-        return self.out_norm(final_out)
+        # Run all experts and compute weighted sum
+        expert_outputs = torch.stack(
+            [expert(x_stacked) for expert in self.experts]
+        )  # (E, B, S//k, D)
+        return torch.einsum("ebsd, bse -> bsd", expert_outputs, routing_weights)
 
     def get_output_length(self, input_length: int) -> int:
         """Calculate output sequence length given input length."""
-        # Two conv layers with stride=2 each = stride 4 total
-        padded = input_length + (4 - input_length % 4) % 4
-        return padded // 4
-
-    def get_aux_loss(self) -> torch.Tensor:
-        """Compute auxiliary losses: load balancing + z-loss."""
-        if self.last_router_logits is None:
-            return torch.tensor(0.0, device=self.conv[0].weight.device)
-
-        # Flatten for loss computation: (B, S, E) -> (B*S, E)
-        logits_flat = self.last_router_logits.view(-1, self.num_experts)
-        probs_flat = self.last_routing_weights.view(-1, self.num_experts)
-
-        balance = load_balancing_loss(probs_flat, self.num_experts, top_k=self.num_experts)
-        z = z_loss(logits_flat)
-
-        return self.aux_loss_coef * balance + self.z_loss_coef * z
+        return input_length // self.k
 
 
 # =============================================================================
-# Shared MoE Projector
+# MoE Projector (Shared Expert + Sparse Routed Experts)
 # =============================================================================
 
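A minimal shape check of the new forward path (a standalone sketch, not repo code: SimpleAdapter is stubbed with a plain 2-layer ReLU MLP since its definition is not part of this diff, and the dimensions are illustrative):

import torch
import torch.nn as nn
import torch.nn.functional as F

B, S, D, k, E, llm_dim = 2, 16, 1280, 4, 4, 2048  # S must be divisible by k

x = torch.randn(B, S, D)

# Router: mean-pool every k frames, then softmax over experts (dense gating)
x_pooled = x.reshape(B, -1, k, D).mean(dim=2)          # (B, S//k, D)
router = nn.Sequential(nn.Linear(D, 512), nn.ReLU(), nn.Linear(512, E))
routing_weights = F.softmax(router(x_pooled), dim=-1)  # (B, S//k, E)

# Frame stacking: concatenate k adjacent frames along the feature axis
x_stacked = x.reshape(B, -1, D * k)                    # (B, S//k, D*k)

# Stand-in for SimpleAdapter(in_dim, hidden, out_dim)
experts = nn.ModuleList(
    [nn.Sequential(nn.Linear(D * k, 4096), nn.ReLU(), nn.Linear(4096, llm_dim)) for _ in range(E)]
)

expert_outputs = torch.stack([e(x_stacked) for e in experts])  # (E, B, S//k, llm_dim)
out = torch.einsum("ebsd, bse -> bsd", expert_outputs, routing_weights)
print(out.shape)  # torch.Size([2, 4, 2048])

One design note: unlike the removed conv/pool branch, the reshape-based path has no padding, so it assumes seq_len is divisible by k; presumably the feature extractor guarantees aligned lengths.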
@@ -232,9 +176,9 @@ class SharedMoEBlock(nn.Module):
         self.router = nn.Linear(input_dim, num_experts, bias=False)
         nn.init.normal_(self.router.weight, mean=0.0, std=0.02)
 
-        self.shared_expert = SwiGLUExpert(input_dim, hidden_dim, output_dim)
+        self.shared_expert = SimpleAdapter(input_dim, hidden_dim, output_dim)
         self.experts = nn.ModuleList(
-            [SwiGLUExpert(input_dim, hidden_dim, output_dim) for _ in range(num_experts)]
+            [SimpleAdapter(input_dim, hidden_dim, output_dim) for _ in range(num_experts)]
         )
 
         self.last_router_logits = None
@@ -307,8 +251,8 @@ def z_loss(router_logits: torch.Tensor) -> torch.Tensor:
     return torch.logsumexp(router_logits.float(), dim=-1).square().mean()
 
 
-class SharedMoEAudioProjector(nn.Module):
-    """Shared expert + sparse routed experts projector."""
+class MoEAudioProjector(nn.Module):
+    """MoE projector with shared expert + sparse routed experts."""
 
     def __init__(self, config):
         super().__init__()
@@ -335,14 +279,12 @@ class SharedMoEAudioProjector(nn.Module):
 
     def _init_weights(self):
         with torch.no_grad():
-            nn.init.orthogonal_(self.moe.shared_expert.gate_proj.weight)
-            nn.init.orthogonal_(self.moe.shared_expert.up_proj.weight)
-            nn.init.orthogonal_(self.moe.shared_expert.down_proj.weight, gain=0.5)
+            nn.init.orthogonal_(self.moe.shared_expert.fc1.weight)
+            nn.init.orthogonal_(self.moe.shared_expert.fc2.weight, gain=0.5)
 
             for expert in self.moe.experts:
-                nn.init.orthogonal_(expert.gate_proj.weight)
-                nn.init.orthogonal_(expert.up_proj.weight)
-                nn.init.orthogonal_(expert.down_proj.weight, gain=0.01)
+                nn.init.orthogonal_(expert.fc1.weight)
+                nn.init.orthogonal_(expert.fc2.weight, gain=0.01)
 
     def get_output_length(self, input_length: int) -> int:
         """Calculate output sequence length given input length."""
@@ -354,7 +296,7 @@ class SharedMoEAudioProjector(nn.Module):
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         batch_size, seq_len, dim = x.size()
 
-        target_dtype = self.moe.shared_expert.gate_proj.weight.dtype
+        target_dtype = self.moe.shared_expert.fc1.weight.dtype
         if x.dtype != target_dtype:
             x = x.to(target_dtype)
 
@@ -503,6 +445,6 @@ class QFormerAudioProjector(nn.Module):
 PROJECTOR_CLASSES = {
     "mlp": MLPAudioProjector,
     "mosa": MOSAProjector,
-    "shared_moe": SharedMoEAudioProjector,
+    "moe": MoEAudioProjector,
     "qformer": QFormerAudioProjector,
 }
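The diff swaps SwiGLUExpert for SimpleAdapter and re-targets the orthogonal init at fc1/fc2 attributes. SimpleAdapter's definition is not shown in this commit; from the attribute names and the "2-layer ReLU adapter" comments, it is presumably shaped like the following sketch (everything beyond the fc1/fc2 names is an assumption):

import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleAdapter(nn.Module):
    # Assumed shape of the adapter this diff references: a plain
    # 2-layer ReLU MLP whose linear layers are named fc1 and fc2,
    # matching the _init_weights calls above.
    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc2(F.relu(self.fc1(x)))

The small gains on fc2 (0.5 for the shared expert, 0.01 for routed experts) keep routed-expert outputs near zero at init, so early training is dominated by the shared path.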
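With the registry change, downstream configs that selected "shared_moe" must now use "moe". A hypothetical lookup (the projector_type config field is assumed, not shown in this commit):

projector_type = getattr(config, "projector_type", "mlp")  # e.g. "moe" after this commit
projector_cls = PROJECTOR_CLASSES[projector_type]
projector = projector_cls(config)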