Training in progress - step 500
- asr_pipeline.py +25 -0
- projectors.py +46 -104
asr_pipeline.py
CHANGED
@@ -476,10 +476,35 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
         text = self.tokenizer.decode(tokens, skip_special_tokens=True).strip()
         # Strip <think>...</think> tags (Qwen3 doesn't respect /no_think prompt)
         text = re.sub(r"<think>.*?</think>\s*", "", text, flags=re.DOTALL).strip()
+        # Collapse spaced-out acronyms (e.g., "I S D S" -> "ISDS")
+        text = self._collapse_acronyms(text)
         # Truncate if a word repeats more than 3 times consecutively
         text = self._truncate_repetitions(text, max_repeats=3)
         return {"text": text}
 
+    def _collapse_acronyms(self, text: str) -> str:
+        """Collapse spaced-out acronyms into single words.
+
+        Converts patterns like "I S D S" to "ISDS" when 2+ single letters
+        are separated by spaces.
+
+        Args:
+            text: Input text with potential spaced acronyms
+
+        Returns:
+            Text with acronyms collapsed
+        """
+        # Match 2+ single letters (case-insensitive) separated by spaces
+        # Pattern: single letter, then one or more (space + single letter)
+        pattern = r"\b([A-Za-z])((?:\s[A-Za-z]){1,})\b"
+
+        def collapse_match(match: re.Match) -> str:
+            # Get the full match and remove spaces
+            full = match.group(0)
+            return full.replace(" ", "").upper()
+
+        return re.sub(pattern, collapse_match, text)
+
     def _truncate_repetitions(self, text: str, max_repeats: int = 3) -> str:
         """Truncate text when a word repeats more than max_repeats times consecutively.
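A quick standalone check of the added regex (not part of the commit; the sample strings are illustrative):

import re

# Same pattern as _collapse_acronyms: 2+ single letters separated by single spaces.
PATTERN = r"\b([A-Za-z])((?:\s[A-Za-z]){1,})\b"

def collapse(text: str) -> str:
    # Mirror collapse_match: drop the internal spaces, uppercase the result.
    return re.sub(PATTERN, lambda m: m.group(0).replace(" ", "").upper(), text)

print(collapse("the I S D S tribunal"))  # -> "the ISDS tribunal"
print(collapse("meet at 3 p m"))         # -> "meet at 3 PM" (lowercase letters get uppercased)
print(collapse("plan A or B"))           # -> unchanged; "A o" fails the trailing \b word boundary

Because matches are case-insensitive and uppercased, any run of two or more single-letter tokens is merged, an acceptable trade-off for ASR output where such runs are almost always spelled-out acronyms.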
projectors.py
CHANGED
@@ -89,124 +89,68 @@ class SwiGLUExpert(nn.Module):
 
 
 class MOSAProjector(nn.Module):
+    """MOSA-Base projector: simple 2-layer router with 4 simple adapters.
+
+    Based on "MOSA: Mixtures of Simple Adapters" (arXiv:2508.18998).
+    Uses softmax gating over all experts (dense MoE) with only cross-entropy loss.
+    Uses frame-stacking for downsampling (like MLP projector).
+    """
+
     def __init__(self, config):
         super().__init__()
        self.encoder_dim = getattr(config, "encoder_dim", None) or 1280
         self.llm_dim = getattr(config, "llm_dim", None) or 2048
-        self.
+        self.k = getattr(config, "projector_pool_stride", 4)
+        self.num_experts = getattr(config, "num_experts", None) or 4  # MOSA-Base uses 4
         adapter_hidden = getattr(config, "adapter_hidden_dim", None) or 4096
 
-        #
-
-        self.z_loss_coef = getattr(config, "router_z_loss_coef", 0.0)
-
-        # Store router state for aux loss computation
-        self.last_router_logits = None
-        self.last_routing_weights = None
-
-        # --- 1. Pre-Norms (CRITICAL for stability) ---
-        self.in_norm = LlamaRMSNorm(self.encoder_dim, eps=1e-8)
-
-        # --- 2. Convolutional Subsampling (Stride 4) ---
-        self.conv = nn.Sequential(
-            nn.Conv1d(self.encoder_dim, self.llm_dim, kernel_size=3, stride=2, padding=1),
-            nn.SiLU(),
-            nn.Conv1d(self.llm_dim, self.llm_dim, kernel_size=3, stride=2, padding=1),
-            nn.SiLU(),
-        )
+        # Frame stacking: concat k adjacent frames then project
+        in_dim = self.encoder_dim * self.k
 
-        # ---
+        # --- 1. Simple Router (MOSA-Base: 2 layers with ReLU) ---
+        # Maps encoder_dim -> 512 -> num_experts
+        router_hidden = getattr(config, "router_hidden_dim", None) or 512
         self.router = nn.Sequential(
-            nn.Linear(self.encoder_dim,
-            nn.ReLU(),
-            nn.Linear(2560, 5120),
-            nn.ReLU(),
-            nn.Linear(5120, 2560),
+            nn.Linear(self.encoder_dim, router_hidden),
             nn.ReLU(),
-            nn.Linear(
-            nn.ReLU(),
-            nn.Linear(1280, self.num_experts),
+            nn.Linear(router_hidden, self.num_experts),
         )
 
-        # ---
+        # --- 2. Experts (Simple 2-layer ReLU adapters per MOSA paper) ---
+        # Each expert: in_dim (stacked frames) -> hidden -> llm_dim
         self.experts = nn.ModuleList(
-            [
-                SimpleAdapter(self.llm_dim, adapter_hidden, self.llm_dim)
-                for _ in range(self.num_experts)
-            ]
+            [SimpleAdapter(in_dim, adapter_hidden, self.llm_dim) for _ in range(self.num_experts)]
         )
 
-        # --- 5. Output Norm ---
-        # Projects often drift in magnitude; this clamps them before the LLM.
-        self.out_norm = LlamaRMSNorm(self.llm_dim, eps=1e-8)
-
-        # Using PyTorch default initialization (like MOSA paper)
-
     def forward(self, x):
-        # x: (B, S,
-        batch_size, seq_len,
+        # x: (B, S, encoder_dim)
+        batch_size, seq_len, dim = x.shape
 
-        #
-
-        #
-        h_conv = self.conv(x_trans).permute(0, 2, 1)  # (B, S//4, llm_dim)
+        # --- 1. Router Branch ---
+        # Mean pool encoder outputs for routing decisions
+        x_pooled = x.reshape(batch_size, -1, self.k, self.encoder_dim).mean(dim=2)  # (B, S//k, D)
 
-        # --- 2.
-        # Mean pool to align receptive fields
-        x_pooled = x_padded.view(batch_size, -1, 4, self.encoder_dim).mean(dim=2)  # (B, S//4, D)
-
-        # Router Logits
-        router_logits = self.router(x_pooled)  # (B, S//4, num_experts)
-
-        # Softmax for Dense MoE (Soft Mixing)
-        routing_weights = F.softmax(router_logits, dim=-1)
-
-        # Store for aux loss computation
-        self.last_router_logits = router_logits
-        self.last_routing_weights = routing_weights
+        # Router logits and softmax gating (dense MoE)
+        routing_weights = F.softmax(self.router(x_pooled), dim=-1)  # (B, S//k, num_experts)
+
+        # --- 2. Frame stacking for experts ---
+        # Reshape to combine k frames: [B, S, D] -> [B, S//k, D*k]
+        x_stacked = x.reshape(batch_size, -1, dim * self.k)
 
         # --- 3. Expert Mixture (Dense Execution) ---
-        #
-        #
-        expert_outputs = torch.stack([expert(h_conv) for expert in self.experts])  # (E, B, S//4, D)
-
-        # Weighted Sum
-        # (Experts, Batch, Seq, Dim) * (Batch, Seq, Experts) -> (Batch, Seq, Dim)
-        final_out = torch.einsum("ebsd, bse -> bsd", expert_outputs, routing_weights)
-
-        return self.out_norm(final_out)
+        # Run all experts and compute weighted sum
+        expert_outputs = torch.stack(
+            [expert(x_stacked) for expert in self.experts]
+        )  # (E, B, S//k, D)
+        return torch.einsum("ebsd, bse -> bsd", expert_outputs, routing_weights)
 
     def get_output_length(self, input_length: int) -> int:
         """Calculate output sequence length given input length."""
-
-        padded = input_length + (4 - input_length % 4) % 4
-        return padded // 4
-
-    def get_aux_loss(self) -> torch.Tensor:
-        """Compute auxiliary losses: load balancing + z-loss."""
-        if self.last_router_logits is None:
-            return torch.tensor(0.0, device=self.conv[0].weight.device)
-
-        # Flatten for loss computation: (B, S, E) -> (B*S, E)
-        logits_flat = self.last_router_logits.view(-1, self.num_experts)
-        probs_flat = self.last_routing_weights.view(-1, self.num_experts)
-
-        balance = load_balancing_loss(probs_flat, self.num_experts, top_k=self.num_experts)
-        z = z_loss(logits_flat)
-
-        return self.aux_loss_coef * balance + self.z_loss_coef * z
+        return input_length // self.k
 
 
 # =============================================================================
-# Shared
+# MoE Projector (Shared Expert + Sparse Routed Experts)
 # =============================================================================
 
 
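Shape sanity check for the new MOSA forward pass, with SimpleAdapter stubbed as the 2-layer ReLU MLP the diff comments describe (the stub and the dimension constants are assumptions for illustration):

import torch
import torch.nn as nn
import torch.nn.functional as F

B, S, D, K, E, LLM = 2, 16, 1280, 4, 4, 2048  # S must be divisible by K

def make_adapter():
    # Stub of SimpleAdapter per the diff comments: in_dim -> hidden -> out_dim with ReLU.
    return nn.Sequential(nn.Linear(D * K, 4096), nn.ReLU(), nn.Linear(4096, LLM))

experts = nn.ModuleList([make_adapter() for _ in range(E)])
router = nn.Sequential(nn.Linear(D, 512), nn.ReLU(), nn.Linear(512, E))

x = torch.randn(B, S, D)
x_pooled = x.reshape(B, -1, K, D).mean(dim=2)            # (B, S//K, D)
weights = F.softmax(router(x_pooled), dim=-1)            # (B, S//K, E)
x_stacked = x.reshape(B, -1, D * K)                      # (B, S//K, D*K)
outs = torch.stack([e(x_stacked) for e in experts])      # (E, B, S//K, LLM)
mixed = torch.einsum("ebsd, bse -> bsd", outs, weights)  # (B, S//K, LLM)
print(mixed.shape)  # torch.Size([2, 4, 2048])

Every expert processes every frame (dense MoE) and the router only soft-mixes the outputs, so there is no sparse assignment to balance; that is why the load-balancing and z-loss bookkeeping could be dropped from this class.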
@@ -232,9 +176,9 @@ class SharedMoEBlock(nn.Module):
         self.router = nn.Linear(input_dim, num_experts, bias=False)
         nn.init.normal_(self.router.weight, mean=0.0, std=0.02)
 
-        self.shared_expert =
+        self.shared_expert = SimpleAdapter(input_dim, hidden_dim, output_dim)
         self.experts = nn.ModuleList(
-            [
+            [SimpleAdapter(input_dim, hidden_dim, output_dim) for _ in range(num_experts)]
         )
 
         self.last_router_logits = None
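SharedMoEBlock's forward is not part of this hunk; the structure above (an always-on shared expert plus a routed ModuleList) suggests the usual shared-expert MoE combination, sketched hypothetically here (top_k=2 is an assumption):

import torch
import torch.nn.functional as F

def shared_moe_forward(x, shared, experts, router, top_k=2):
    # Hypothetical forward consistent with the fields above; the actual
    # implementation is not shown in this diff.
    probs = F.softmax(router(x), dim=-1)             # (B, T, E)
    weights, idx = torch.topk(probs, top_k, dim=-1)  # top-k routed experts per token
    out = shared(x)                                  # shared expert is always active
    for e, expert in enumerate(experts):
        # Weight of expert e wherever it was selected, else 0 (computed densely for clarity).
        w = (weights * (idx == e)).sum(dim=-1, keepdim=True)
        out = out + w * expert(x)
    return out

Some implementations renormalize the top-k weights to sum to 1 before mixing; the diff does not show which variant this repo uses.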
@@ -307,8 +251,8 @@ def z_loss(router_logits: torch.Tensor) -> torch.Tensor:
     return torch.logsumexp(router_logits.float(), dim=-1).square().mean()
 
 
-class SharedMoEAudioProjector(nn.Module):
-    """
+class MoEAudioProjector(nn.Module):
+    """MoE projector with shared expert + sparse routed experts."""
 
     def __init__(self, config):
         super().__init__()
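For reference, the z_loss retained as context above penalizes large-magnitude router logits; a standalone check of the same expression:

import torch

def z_loss(router_logits: torch.Tensor) -> torch.Tensor:
    # Same body as in the diff: squared logsumexp over the expert dimension.
    return torch.logsumexp(router_logits.float(), dim=-1).square().mean()

small = torch.zeros(8, 4)         # uniform, small logits
large = 10.0 * torch.randn(8, 4)  # large-magnitude logits
print(z_loss(small))              # logsumexp(0,0,0,0) = log(4) ~ 1.386, squared ~ 1.92
print(z_loss(large) > z_loss(small))  # True with overwhelming probability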
@@ -335,14 +279,12 @@ class SharedMoEAudioProjector(nn.Module):
 
     def _init_weights(self):
         with torch.no_grad():
-            nn.init.orthogonal_(self.moe.shared_expert.
-            nn.init.orthogonal_(self.moe.shared_expert.
-            nn.init.orthogonal_(self.moe.shared_expert.down_proj.weight, gain=0.5)
+            nn.init.orthogonal_(self.moe.shared_expert.fc1.weight)
+            nn.init.orthogonal_(self.moe.shared_expert.fc2.weight, gain=0.5)
 
         for expert in self.moe.experts:
-            nn.init.orthogonal_(expert.
-            nn.init.orthogonal_(expert.
-            nn.init.orthogonal_(expert.down_proj.weight, gain=0.01)
+            nn.init.orthogonal_(expert.fc1.weight)
+            nn.init.orthogonal_(expert.fc2.weight, gain=0.01)
 
     def get_output_length(self, input_length: int) -> int:
         """Calculate output sequence length given input length."""
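The asymmetric gains are the interesting part: for an orthogonal matrix, the gain scales the layer's output norm directly, so routed experts (gain=0.01 on fc2) start near-silent while the shared expert (gain=0.5) carries the early signal. A minimal check (the bare square Linear is a stand-in for an adapter layer):

import torch
import torch.nn as nn

x = torch.randn(32, 512)
for gain in (0.5, 0.01):
    fc2 = nn.Linear(512, 512, bias=False)
    nn.init.orthogonal_(fc2.weight, gain=gain)
    # Square orthogonal weights preserve norms exactly up to the gain factor.
    print(gain, (fc2(x).norm() / x.norm()).item())  # ~0.5 and ~0.01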
@@ -354,7 +296,7 @@ class SharedMoEAudioProjector(nn.Module):
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         batch_size, seq_len, dim = x.size()
 
-        target_dtype = self.moe.shared_expert.
+        target_dtype = self.moe.shared_expert.fc1.weight.dtype
         if x.dtype != target_dtype:
             x = x.to(target_dtype)
 
@@ -503,6 +445,6 @@ class QFormerAudioProjector(nn.Module):
 PROJECTOR_CLASSES = {
     "mlp": MLPAudioProjector,
     "mosa": MOSAProjector,
-    "
+    "moe": MoEAudioProjector,
     "qformer": QFormerAudioProjector,
 }
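Typical downstream use of the registry, assuming it is driven by a config field (the projector_type field name and the SimpleNamespace config are assumptions for illustration; run inside projectors.py's namespace):

from types import SimpleNamespace

config = SimpleNamespace(projector_type="mosa", encoder_dim=1280, llm_dim=2048)
projector = PROJECTOR_CLASSES[config.projector_type](config)  # -> MOSAProjector
print(projector.get_output_length(400))  # 100 frames after 4x frame stacking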