mazesmazes committed on
Commit
bd16460
·
verified ·
1 Parent(s): c26201f

Training in progress - step 1000

Browse files
Files changed (5) hide show
  1. asr_modeling.py +4 -1
  2. chat_template.jinja +94 -89
  3. projectors.py +155 -136
  4. tokenizer.json +2 -2
  5. tokenizer_config.json +9 -7
asr_modeling.py CHANGED
@@ -419,7 +419,10 @@ class ASRModel(PreTrainedModel, GenerationMixin):
419
  # Compute per-sample encoder output lengths using conv formulas
420
  encoder_lengths = self._compute_encoder_output_lengths(audio_attention_mask)
421
  token_counts = torch.tensor(
422
- [self.projector.get_output_length(int(length.item())) for length in encoder_lengths],
 
 
 
423
  device=audio_embeds.device,
424
  )
425
 
 
419
  # Compute per-sample encoder output lengths using conv formulas
420
  encoder_lengths = self._compute_encoder_output_lengths(audio_attention_mask)
421
  token_counts = torch.tensor(
422
+ [
423
+ self.projector.get_output_length(int(length.item()))
424
+ for length in encoder_lengths
425
+ ],
426
  device=audio_embeds.device,
427
  )
428
 
chat_template.jinja CHANGED
@@ -1,89 +1,94 @@
1
- {%- if tools %}
2
- {{- '<|im_start|>system\n' }}
3
- {%- if messages[0].role == 'system' %}
4
- {{- messages[0].content + '\n\n' }}
5
- {%- endif %}
6
- {{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
7
- {%- for tool in tools %}
8
- {{- "\n" }}
9
- {{- tool | tojson }}
10
- {%- endfor %}
11
- {{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
12
- {%- else %}
13
- {%- if messages[0].role == 'system' %}
14
- {{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }}
15
- {%- endif %}
16
- {%- endif %}
17
- {%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
18
- {%- for message in messages[::-1] %}
19
- {%- set index = (messages|length - 1) - loop.index0 %}
20
- {%- if ns.multi_step_tool and message.role == "user" and message.content is string and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}
21
- {%- set ns.multi_step_tool = false %}
22
- {%- set ns.last_query_index = index %}
23
- {%- endif %}
24
- {%- endfor %}
25
- {%- for message in messages %}
26
- {%- if message.content is string %}
27
- {%- set content = message.content %}
28
- {%- else %}
29
- {%- set content = '' %}
30
- {%- endif %}
31
- {%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
32
- {{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }}
33
- {%- elif message.role == "assistant" %}
34
- {%- set reasoning_content = '' %}
35
- {%- if message.reasoning_content is string %}
36
- {%- set reasoning_content = message.reasoning_content %}
37
- {%- else %}
38
- {%- if '</think>' in content %}
39
- {%- set reasoning_content = content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
40
- {%- set content = content.split('</think>')[-1].lstrip('\n') %}
41
- {%- endif %}
42
- {%- endif %}
43
- {%- if loop.index0 > ns.last_query_index %}
44
- {%- if loop.last or (not loop.last and reasoning_content) %}
45
- {{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content.strip('\n') + '\n</think>\n\n' + content.lstrip('\n') }}
46
- {%- else %}
47
- {{- '<|im_start|>' + message.role + '\n' + content }}
48
- {%- endif %}
49
- {%- else %}
50
- {{- '<|im_start|>' + message.role + '\n' + content }}
51
- {%- endif %}
52
- {%- if message.tool_calls %}
53
- {%- for tool_call in message.tool_calls %}
54
- {%- if (loop.first and content) or (not loop.first) %}
55
- {{- '\n' }}
56
- {%- endif %}
57
- {%- if tool_call.function %}
58
- {%- set tool_call = tool_call.function %}
59
- {%- endif %}
60
- {{- '<tool_call>\n{"name": "' }}
61
- {{- tool_call.name }}
62
- {{- '", "arguments": ' }}
63
- {%- if tool_call.arguments is string %}
64
- {{- tool_call.arguments }}
65
- {%- else %}
66
- {{- tool_call.arguments | tojson }}
67
- {%- endif %}
68
- {{- '}\n</tool_call>' }}
69
- {%- endfor %}
70
- {%- endif %}
71
- {{- '<|im_end|>\n' }}
72
- {%- elif message.role == "tool" %}
73
- {%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
74
- {{- '<|im_start|>user' }}
75
- {%- endif %}
76
- {{- '\n<tool_response>\n' }}
77
- {{- content }}
78
- {{- '\n</tool_response>' }}
79
- {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
80
- {{- '<|im_end|>\n' }}
81
- {%- endif %}
82
- {%- endif %}
83
- {%- endfor %}
84
- {%- if add_generation_prompt %}
85
- {{- '<|im_start|>assistant\n' }}
86
- {%- if true %}
87
- {{- '<think>\n\n</think>\n\n' }}
88
- {%- endif %}
89
- {%- endif %}
 
 
 
 
 
 
1
+ {# ───── defaults ───── #}
2
+ {%- if enable_thinking is not defined -%}
3
+ {%- set enable_thinking = true -%}
4
+ {%- endif -%}
5
+
6
+ {# ───── reasoning mode ───── #}
7
+ {%- if enable_thinking -%}
8
+ {%- set reasoning_mode = "/think" -%}
9
+ {%- else -%}
10
+ {%- set reasoning_mode = "/no_think" -%}
11
+ {%- endif -%}
12
+
13
+ {# ───── header (system message) ───── #}
14
+ {{- "<|im_start|>system\n" -}}
15
+
16
+ {%- if messages[0].role == "system" -%}
17
+ {%- set system_message = messages[0].content -%}
18
+ {%- if "/no_think" in system_message -%}
19
+ {%- set reasoning_mode = "/no_think" -%}
20
+ {%- elif "/think" in system_message -%}
21
+ {%- set reasoning_mode = "/think" -%}
22
+ {%- endif -%}
23
+ {%- set custom_instructions = system_message.replace("/no_think", "").replace("/think", "").rstrip() -%}
24
+ {%- endif -%}
25
+
26
+ {%- if "/system_override" in system_message -%}
27
+ {{- custom_instructions.replace("/system_override", "").rstrip() -}}
28
+ {{- "<|im_end|>\n" -}}
29
+ {%- else -%}
30
+ {{- "## Metadata\n\n" -}}
31
+ {{- "Knowledge Cutoff Date: June 2025\n" -}}
32
+ {%- set today = strftime_now("%d %B %Y") -%}
33
+ {{- "Today Date: " ~ today ~ "\n" -}}
34
+ {{- "Reasoning Mode: " + reasoning_mode + "\n\n" -}}
35
+
36
+ {{- "## Custom Instructions\n\n" -}}
37
+ {%- if custom_instructions -%}
38
+ {{- custom_instructions + "\n\n" -}}
39
+ {%- elif reasoning_mode == "/think" -%}
40
+ {{- "You are a helpful AI assistant named SmolLM, trained by Hugging Face. Your role as an assistant involves thoroughly exploring questions through a systematic thinking process before providing the final precise and accurate solutions. This requires engaging in a comprehensive cycle of analysis, summarizing, exploration, reassessment, reflection, backtracking, and iteration to develop well-considered thinking process. Please structure your response into two main sections: Thought and Solution using the specified format: <think> Thought section </think> Solution section. In the Thought section, detail your reasoning process in steps. Each step should include detailed considerations such as analysing questions, summarizing relevant findings, brainstorming new ideas, verifying the accuracy of the current steps, refining any errors, and revisiting previous steps. In the Solution section, based on various attempts, explorations, and reflections from the Thought section, systematically present the final solution that you deem correct. The Solution section should be logical, accurate, and concise and detail necessary steps needed to reach the conclusion.\n\n" -}}
41
+ {%- else -%}
42
+ {{- "You are a helpful AI assistant named SmolLM, trained by Hugging Face.\n\n" -}}
43
+ {%- endif -%}
44
+
45
+ {%- if xml_tools or python_tools or tools -%}
46
+ {{- "### Tools\n\n" -}}
47
+ {%- if xml_tools or tools -%}
48
+ {%- if tools -%}
49
+ {%- set xml_tools = tools -%}
50
+ {%- endif -%}
51
+ {%- set ns = namespace(xml_tool_string="You may call one or more functions to assist with the user query.\nYou are provided with function signatures within <tools></tools> XML tags:\n\n<tools>\n") -%}
52
+ {%- for tool in xml_tools[:] -%} {# The slicing makes sure that xml_tools is a list #}
53
+ {%- set ns.xml_tool_string = ns.xml_tool_string ~ (tool | string) ~ "\n" -%}
54
+ {%- endfor -%}
55
+ {%- set xml_tool_string = ns.xml_tool_string + "</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call>" -%}
56
+ {{- xml_tool_string -}}
57
+ {%- endif -%}
58
+ {%- if python_tools -%}
59
+ {%- set ns = namespace(python_tool_string="When you send a message containing Python code between '<code>' and '</code>' tags, it will be executed in a stateful Jupyter notebook environment, and you will then be given the output to continued reasoning in an agentic loop.\n\nYou can use the following tools in your python code like regular functions:\n<tools>\n") -%}
60
+ {%- for tool in python_tools[:] -%} {# The slicing makes sure that python_tools is a list #}
61
+ {%- set ns.python_tool_string = ns.python_tool_string ~ (tool | string) ~ "\n" -%}
62
+ {%- endfor -%}
63
+ {%- set python_tool_string = ns.python_tool_string + "</tools>\n\nThe state persists between code executions: so variables that you define in one step are still available thereafter." -%}
64
+ {{- python_tool_string -}}
65
+ {%- endif -%}
66
+ {{- "\n\n" -}}
67
+ {{- "<|im_end|>\n" -}}
68
+ {%- endif -%}
69
+ {%- endif -%}
70
+ {# ───── main loop ───── #}
71
+ {%- for message in messages -%}
72
+ {%- set content = message.content if message.content is string else "" -%}
73
+ {%- if message.role == "user" -%}
74
+ {{ "<|im_start|>" + message.role + "\n" + content + "<|im_end|>\n" }}
75
+ {%- elif message.role == "assistant" -%}
76
+ {% generation %}
77
+ {%- if reasoning_mode == "/think" -%}
78
+ {{ "<|im_start|>assistant\n" + content.lstrip("\n") + "<|im_end|>\n" }}
79
+ {%- else -%}
80
+ {{ "<|im_start|>assistant\n" + "<think>\n\n</think>\n" + content.lstrip("\n") + "<|im_end|>\n" }}
81
+ {%- endif -%}
82
+ {% endgeneration %}
83
+ {%- elif message.role == "tool" -%}
84
+ {{ "<|im_start|>" + "user\n" + content + "<|im_end|>\n" }}
85
+ {%- endif -%}
86
+ {%- endfor -%}
87
+ {# ───── generation prompt ───── #}
88
+ {%- if add_generation_prompt -%}
89
+ {%- if reasoning_mode == "/think" -%}
90
+ {{ "<|im_start|>assistant\n" }}
91
+ {%- else -%}
92
+ {{ "<|im_start|>assistant\n" + "<think>\n\n</think>\n" }}
93
+ {%- endif -%}
94
+ {%- endif -%}
projectors.py CHANGED
@@ -87,6 +87,34 @@ class SimpleAdapter(nn.Module):
87
  return self.fc2(self.act(self.fc1(x)))
88
 
89
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
  class MOSAProjector(nn.Module):
91
  """MOSA-Base projector: simple 2-layer ReLU router with 4 simple adapters.
92
 
@@ -166,109 +194,18 @@ class MOSAProjector(nn.Module):
166
 
167
 
168
  # =============================================================================
169
- # MoE Projector (Shared Expert + Sparse Routed Experts)
170
  # =============================================================================
171
 
172
 
173
- class SharedMoEBlock(nn.Module):
174
- """MoE block with Shared + Sigmoid-Routed Experts."""
175
-
176
- def __init__(
177
- self,
178
- input_dim: int,
179
- hidden_dim: int,
180
- output_dim: int,
181
- num_experts: int = 4,
182
- top_k: int = 2,
183
- ):
184
- super().__init__()
185
- self.num_experts = num_experts
186
- self.top_k = top_k
187
- self.output_dim = output_dim
188
-
189
- # RMSNorm before routing
190
- self.norm = LlamaRMSNorm(input_dim, eps=1e-8)
191
-
192
- self.router = nn.Linear(input_dim, num_experts, bias=False)
193
- nn.init.normal_(self.router.weight, mean=0.0, std=0.02)
194
-
195
- self.shared_expert = SimpleAdapter(input_dim, hidden_dim, output_dim)
196
- self.experts = nn.ModuleList(
197
- [SimpleAdapter(input_dim, hidden_dim, output_dim) for _ in range(num_experts)]
198
- )
199
-
200
- self.last_router_logits = None
201
- self.last_router_probs = None
202
-
203
- def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
204
- batch_size, seq_len, dim = hidden_states.shape
205
-
206
- # 1. Apply Shared Expert
207
- normed_states = self.norm(hidden_states)
208
- shared_out = self.shared_expert(normed_states)
209
-
210
- # 2. Router Logic (Sigmoid Style)
211
- flat_hidden = normed_states.view(-1, dim)
212
- router_logits = self.router(flat_hidden)
213
-
214
- # Sigmoid routing
215
- router_probs = torch.sigmoid(router_logits)
216
-
217
- self.last_router_logits = router_logits
218
- self.last_router_probs = router_probs
219
-
220
- # 3. Top-K Selection
221
- top_k_scores, top_k_indices = torch.topk(router_probs, self.top_k, dim=-1)
222
-
223
- # Normalize weights
224
- top_k_weights = top_k_scores / (top_k_scores.sum(dim=-1, keepdim=True) + 1e-6)
225
- top_k_weights = top_k_weights.to(hidden_states.dtype)
226
-
227
- # 4. Dispatch
228
- routed_out = self._dispatch_experts(flat_hidden, top_k_indices, top_k_weights)
229
- routed_out = routed_out.view(batch_size, seq_len, -1)
230
-
231
- return shared_out + routed_out
232
-
233
- def _dispatch_experts(
234
- self,
235
- hidden_states: torch.Tensor,
236
- top_k_indices: torch.Tensor,
237
- top_k_weights: torch.Tensor,
238
- ) -> torch.Tensor:
239
- num_tokens = hidden_states.shape[0]
240
- output = torch.zeros(
241
- num_tokens, self.output_dim, device=hidden_states.device, dtype=hidden_states.dtype
242
- )
243
-
244
- for expert_idx, expert in enumerate(self.experts):
245
- expert_mask = top_k_indices == expert_idx
246
- if not expert_mask.any():
247
- continue
248
-
249
- token_indices, slot_indices = torch.where(expert_mask)
250
- expert_input = hidden_states[token_indices]
251
- expert_output = expert(expert_input).to(output.dtype)
252
- weights = top_k_weights[token_indices, slot_indices].unsqueeze(-1)
253
- output.index_add_(0, token_indices, expert_output * weights)
254
-
255
- return output
256
-
257
-
258
- def load_balancing_loss(router_probs: torch.Tensor, num_experts: int, top_k: int) -> torch.Tensor:
259
- """Auxiliary loss to encourage balanced expert usage."""
260
- prob_per_expert = router_probs.mean(dim=0)
261
- target_mean = prob_per_expert.mean()
262
- return (prob_per_expert - target_mean).square().sum() * num_experts
263
-
264
-
265
- def z_loss(router_logits: torch.Tensor) -> torch.Tensor:
266
- """Z-loss to prevent router logits from growing too large."""
267
- return torch.logsumexp(router_logits.float(), dim=-1).square().mean()
268
 
 
 
269
 
270
- class MoEAudioProjector(nn.Module):
271
- """MoE projector with shared expert + sparse routed experts."""
272
 
273
  def __init__(self, config):
274
  """Initialize MoE projector.
@@ -279,40 +216,59 @@ class MoEAudioProjector(nn.Module):
279
  super().__init__()
280
 
281
  self.k = getattr(config, "projector_pool_stride", 4)
282
- encoder_dim = config.encoder_dim
283
 
284
- # Depthwise Conv for temporal mixing
285
- self.temporal_conv = nn.Conv1d(
286
- encoder_dim, encoder_dim, kernel_size=3, padding=1, groups=encoder_dim
287
- )
 
 
 
288
 
289
- in_dim = encoder_dim * self.k
290
  out_dim = config.llm_dim
291
- hidden_dim = getattr(config, "projector_hidden_dim", None) or in_dim
292
 
 
 
 
 
293
  self.num_experts = getattr(config, "num_experts", 4)
294
  self.top_k = getattr(config, "num_experts_per_tok", 2)
295
- self.aux_loss_coef = getattr(config, "router_aux_loss_coef", 0.02)
296
- self.z_loss_coef = getattr(config, "router_z_loss_coef", 0.001)
297
 
298
- self.moe = SharedMoEBlock(in_dim, hidden_dim, out_dim, self.num_experts, self.top_k)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
299
  self._init_weights()
300
 
 
 
301
  def _init_weights(self):
 
302
  with torch.no_grad():
303
- nn.init.orthogonal_(self.moe.shared_expert.fc1.weight)
304
- nn.init.orthogonal_(self.moe.shared_expert.fc2.weight, gain=0.5)
305
 
306
- for expert in self.moe.experts:
307
- nn.init.orthogonal_(expert.fc1.weight)
308
- nn.init.orthogonal_(expert.fc2.weight, gain=0.01)
 
309
 
310
  def get_output_length(self, input_length: int) -> int:
311
- """Calculate output sequence length given input length."""
312
- # Temporal pooling with stride k
313
- if input_length % self.k:
314
- input_length += self.k - input_length % self.k
315
- return input_length // self.k
316
 
317
  def forward(self, x: torch.Tensor) -> torch.Tensor:
318
  """Project audio features using shared + sparse MoE.
@@ -323,32 +279,95 @@ class MoEAudioProjector(nn.Module):
323
  Returns:
324
  Projected features of shape [batch, out_len, llm_dim]
325
  """
326
- batch_size, seq_len, dim = x.size()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
327
 
328
- target_dtype = self.moe.shared_expert.fc1.weight.dtype
329
- if x.dtype != target_dtype:
330
- x = x.to(target_dtype)
 
 
 
 
331
 
332
- # Temporal Context Injection
333
- x_ctx = x.transpose(1, 2)
334
- x_ctx = self.temporal_conv(x_ctx)
335
- x = x + x_ctx.transpose(1, 2)
336
 
337
- if seq_len % self.k:
338
- x = F.pad(x, (0, 0, 0, self.k - seq_len % self.k))
339
 
340
- x = x.view(batch_size, -1, dim * self.k)
 
341
 
342
- return self.moe(x)
 
343
 
344
- def get_aux_loss(self) -> torch.Tensor:
345
- if self.moe.last_router_logits is None:
346
- return torch.tensor(0.0, device=self.moe.router.weight.device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
347
 
348
- balance = load_balancing_loss(self.moe.last_router_probs, self.num_experts, self.top_k)
349
- z = z_loss(self.moe.last_router_logits)
 
350
 
351
- return self.aux_loss_coef * balance + self.z_loss_coef * z
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
352
 
353
 
354
  # =============================================================================
 
87
  return self.fc2(self.act(self.fc1(x)))
88
 
89
 
90
+ class SwiGLU(nn.Module):
91
+ """SwiGLU activation with gated linear units (used in LLaMA, Mistral, etc.)."""
92
+
93
+ def __init__(self, dim: int, hidden_dim: int, bias: bool = False):
94
+ super().__init__()
95
+ self.w1 = nn.Linear(dim, hidden_dim, bias=bias) # Gate
96
+ self.w2 = nn.Linear(dim, hidden_dim, bias=bias) # Value
97
+ self.w3 = nn.Linear(hidden_dim, dim, bias=bias) # Output
98
+
99
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
100
+ return self.w3(F.silu(self.w1(x)) * self.w2(x))
101
+
102
+
103
+ class AsymmetricSwiGLU(nn.Module):
104
+ """SwiGLU that handles different input and output dimensions."""
105
+
106
+ def __init__(
107
+ self, in_features: int, hidden_features: int, out_features: int, bias: bool = False
108
+ ):
109
+ super().__init__()
110
+ self.w1 = nn.Linear(in_features, hidden_features, bias=bias) # Gate
111
+ self.w2 = nn.Linear(in_features, hidden_features, bias=bias) # Value
112
+ self.w3 = nn.Linear(hidden_features, out_features, bias=bias) # Output
113
+
114
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
115
+ return self.w3(F.silu(self.w1(x)) * self.w2(x))
116
+
117
+
118
  class MOSAProjector(nn.Module):
119
  """MOSA-Base projector: simple 2-layer ReLU router with 4 simple adapters.
120
 
 
194
 
195
 
196
  # =============================================================================
197
+ # MoE Projector (Pure PyTorch with Shared Expert)
198
  # =============================================================================
199
 
200
 
201
+ class MoEAudioProjector(nn.Module):
202
+ """MoE projector with shared expert (DeepSeek-style), pure PyTorch implementation.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
203
 
204
+ Uses 4 sparse experts with top-2 routing plus a shared expert that processes all tokens.
205
+ No external dependencies (megablocks removed).
206
 
207
+ Architecture matches main branch: norm → experts(in_dim → hidden → out_dim)
208
+ """
209
 
210
  def __init__(self, config):
211
  """Initialize MoE projector.
 
216
  super().__init__()
217
 
218
  self.k = getattr(config, "projector_pool_stride", 4)
219
+ self.aux_coef = getattr(config, "router_aux_loss_coef", 0.01)
220
 
221
+ # Stability coefficients
222
+ self.router_z_loss_coef = getattr(
223
+ config, "router_z_loss_coef", 1e-4
224
+ ) # Prevents logit explosion
225
+ self.router_jitter_noise = getattr(
226
+ config, "router_jitter_noise", 0.01
227
+ ) # Prevents expert collapse
228
 
229
+ in_dim = config.encoder_dim * self.k
230
  out_dim = config.llm_dim
 
231
 
232
+ # Expert hidden dim (default = output dim)
233
+ hidden_dim = getattr(config, "projector_hidden_dim", None) or out_dim
234
+
235
+ # Number of experts and top-k selection
236
  self.num_experts = getattr(config, "num_experts", 4)
237
  self.top_k = getattr(config, "num_experts_per_tok", 2)
 
 
238
 
239
+ # A. Normalize stacked input (like main branch SharedMoEBlock)
240
+ self.norm = LlamaRMSNorm(in_dim, eps=1e-6)
241
+
242
+ # B. Router (operates on stacked input)
243
+ self.router = nn.Linear(in_dim, self.num_experts, bias=False)
244
+
245
+ # C. Experts: simple 2-layer MLP (same as MLPAudioProjector)
246
+ self.experts = nn.ModuleList(
247
+ [SimpleAdapter(in_dim, hidden_dim, out_dim) for _ in range(self.num_experts)]
248
+ )
249
+
250
+ # D. Shared Expert (same architecture)
251
+ self.shared_expert = SimpleAdapter(in_dim, hidden_dim, out_dim)
252
+
253
+ # E. Initialize weights for stable training
254
  self._init_weights()
255
 
256
+ self.last_aux_loss = torch.tensor(0.0)
257
+
258
  def _init_weights(self):
259
+ """Initialize weights for stable training start."""
260
  with torch.no_grad():
261
+ # Router: small weights -> uniform probability
262
+ nn.init.normal_(self.router.weight, mean=0.0, std=0.02)
263
 
264
+ # Experts: xavier for fc1, small for fc2 (output)
265
+ for expert in [self.shared_expert, *self.experts]:
266
+ nn.init.xavier_uniform_(expert.fc1.weight)
267
+ nn.init.normal_(expert.fc2.weight, mean=0.0, std=0.01) # Small init
268
 
269
  def get_output_length(self, input_length: int) -> int:
270
+ """Calculate output sequence length given input length (matches MLP projector)."""
271
+ return (input_length - self.k) // self.k + 1
 
 
 
272
 
273
  def forward(self, x: torch.Tensor) -> torch.Tensor:
274
  """Project audio features using shared + sparse MoE.
 
279
  Returns:
280
  Projected features of shape [batch, out_len, llm_dim]
281
  """
282
+ # 1. Frame Stacking
283
+ batch, seq, dim = x.shape
284
+ out_len = (seq - self.k) // self.k + 1
285
+ x = x[:, : out_len * self.k, :]
286
+ x = x.reshape(batch, out_len, dim * self.k)
287
+
288
+ # 2. Normalize stacked input (like main branch SharedMoEBlock)
289
+ x = self.norm(x)
290
+ flat_x = x.view(-1, x.size(-1)) # [tokens, in_dim]
291
+
292
+ # 3. Shared Expert (compute first, creates output tensor)
293
+ output = self.shared_expert(flat_x)
294
+
295
+ # 4. Sparse Experts (in-place add to shared output)
296
+ self.last_aux_loss = self._forward_sparse(flat_x, output)
297
+
298
+ return output.view(batch, out_len, -1)
299
+
300
+ def _forward_sparse(self, x: torch.Tensor, output: torch.Tensor) -> torch.Tensor:
301
+ """Stability-hardened sparse expert dispatch (in-place add to output).
302
+
303
+ Args:
304
+ x: Flattened input of shape [tokens, dim]
305
+ output: Output tensor to add sparse expert results into (in-place)
306
+
307
+ Returns:
308
+ Auxiliary loss tensor
309
+ """
310
+ # A. Router Logic with Jitter
311
+ logits = self.router(x)
312
 
313
+ if self.training and self.router_jitter_noise > 0:
314
+ # Jitter: multiply by uniform noise (1-eps, 1+eps) to shake decision boundary
315
+ # Prevents router from getting stuck on one expert early in training
316
+ noise = torch.empty_like(logits).uniform_(
317
+ 1.0 - self.router_jitter_noise, 1.0 + self.router_jitter_noise
318
+ )
319
+ logits = logits * noise
320
 
321
+ # Force float32 for softmax (bf16/fp16 exponentials can overflow)
322
+ probs = torch.softmax(logits, dim=-1, dtype=torch.float32).type_as(x)
 
 
323
 
324
+ # B. Top-K Selection
325
+ top_k_weights, top_k_indices = torch.topk(probs, self.top_k, dim=-1)
326
 
327
+ # Normalize weights so they sum to 1.0
328
+ top_k_weights = top_k_weights / (top_k_weights.sum(dim=-1, keepdim=True) + 1e-6)
329
 
330
+ # C. Aux Loss + Z-Loss
331
+ aux_loss = torch.tensor(0.0, device=x.device)
332
 
333
+ if self.training:
334
+ # Load balancing loss (batch-size invariant)
335
+ prob_per_expert = probs.mean(0) # [num_experts]
336
+ target = 1.0 / self.num_experts
337
+ balance_loss = (
338
+ self.aux_coef * ((prob_per_expert - target) ** 2).mean() * self.num_experts
339
+ )
340
+
341
+ # Z-loss: penalty on large logits to prevent softmax saturation
342
+ z_loss = self.router_z_loss_coef * torch.logsumexp(logits, dim=-1).pow(2).mean()
343
+
344
+ aux_loss = balance_loss + z_loss
345
+
346
+ # D. Dispatch Loop (in-place add to output)
347
+ for i, expert in enumerate(self.experts):
348
+ # Create boolean mask for tokens that selected Expert 'i'
349
+ mask = top_k_indices == i
350
 
351
+ if mask.any():
352
+ # token_idx = which tokens, k_idx = 1st or 2nd choice
353
+ token_idx, k_idx = torch.where(mask)
354
 
355
+ # Gather inputs and compute
356
+ expert_input = x[token_idx]
357
+ expert_output = expert(expert_input)
358
+
359
+ # Apply routing weight
360
+ weight = top_k_weights[token_idx, k_idx].unsqueeze(-1)
361
+ weighted_output = (expert_output * weight).type_as(output)
362
+
363
+ # Scatter back in-place (index_add_ is atomic and deterministic)
364
+ output.index_add_(0, token_idx, weighted_output)
365
+
366
+ return aux_loss
367
+
368
+ def get_aux_loss(self) -> torch.Tensor:
369
+ """Return auxiliary load balancing loss."""
370
+ return self.last_aux_loss
371
 
372
 
373
  # =============================================================================
tokenizer.json CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:33b674fb8444e2553eae8f1b261093371920a28ef75b5c18f4deb3f9217ed0ba
3
- size 11422834
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d4aeaf198f783cbf58d8cd59812baac429ffe49147bf9648f6618de20b8d4a4c
3
+ size 17209003
tokenizer_config.json CHANGED
@@ -1,17 +1,19 @@
1
  {
2
- "add_prefix_space": false,
3
  "backend": "tokenizers",
4
  "bos_token": null,
5
- "clean_up_tokenization_spaces": false,
6
  "eos_token": "<|im_end|>",
7
- "errors": "replace",
8
  "extra_special_tokens": [
9
  "<audio>"
10
  ],
 
11
  "is_local": false,
 
 
 
 
12
  "model_max_length": 131072,
13
- "pad_token": "<|endoftext|>",
14
- "split_special_tokens": false,
15
- "tokenizer_class": "Qwen2Tokenizer",
16
- "unk_token": null
17
  }
 
1
  {
 
2
  "backend": "tokenizers",
3
  "bos_token": null,
4
+ "clean_up_tokenization_spaces": true,
5
  "eos_token": "<|im_end|>",
 
6
  "extra_special_tokens": [
7
  "<audio>"
8
  ],
9
+ "fast": false,
10
  "is_local": false,
11
+ "model_input_names": [
12
+ "input_ids",
13
+ "attention_mask"
14
+ ],
15
  "model_max_length": 131072,
16
+ "model_specific_special_tokens": {},
17
+ "pad_token": "<|finetune_right_pad_id|>",
18
+ "tokenizer_class": "TokenizersBackend"
 
19
  }